diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go new file mode 100644 index 000000000..1f3cd5785 --- /dev/null +++ b/mock_ack_frame_source_test.go @@ -0,0 +1,59 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: AckFrameSource) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockAckFrameSource is a mock of AckFrameSource interface +type MockAckFrameSource struct { + ctrl *gomock.Controller + recorder *MockAckFrameSourceMockRecorder +} + +// MockAckFrameSourceMockRecorder is the mock recorder for MockAckFrameSource +type MockAckFrameSourceMockRecorder struct { + mock *MockAckFrameSource +} + +// NewMockAckFrameSource creates a new mock instance +func NewMockAckFrameSource(ctrl *gomock.Controller) *MockAckFrameSource { + mock := &MockAckFrameSource{ctrl: ctrl} + mock.recorder = &MockAckFrameSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAckFrameSource) EXPECT() *MockAckFrameSourceMockRecorder { + return m.recorder +} + +// GetAckFrame mocks base method +func (m *MockAckFrameSource) GetAckFrame() *wire.AckFrame { + ret := m.ctrl.Call(m, "GetAckFrame") + ret0, _ := ret[0].(*wire.AckFrame) + return ret0 +} + +// GetAckFrame indicates an expected call of GetAckFrame +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame)) +} + +// GetStopWaitingFrame mocks base method +func (m *MockAckFrameSource) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame { + ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0) + ret0, _ := ret[0].(*wire.StopWaitingFrame) + return ret0 +} + +// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame +func (mr *MockAckFrameSourceMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetStopWaitingFrame), arg0) +} diff --git a/mock_packer_test.go b/mock_packer_test.go new file mode 100644 index 000000000..a9e49a00d --- /dev/null +++ b/mock_packer_test.go @@ -0,0 +1,129 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: Packer) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockPacker is a mock of Packer interface +type MockPacker struct { + ctrl *gomock.Controller + recorder *MockPackerMockRecorder +} + +// MockPackerMockRecorder is the mock recorder for MockPacker +type MockPackerMockRecorder struct { + mock *MockPacker +} + +// NewMockPacker creates a new mock instance +func NewMockPacker(ctrl *gomock.Controller) *MockPacker { + mock := &MockPacker{ctrl: ctrl} + mock.recorder = &MockPackerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPacker) EXPECT() *MockPackerMockRecorder { + return m.recorder +} + +// ChangeDestConnectionID mocks base method +func (m *MockPacker) ChangeDestConnectionID(arg0 protocol.ConnectionID) { + m.ctrl.Call(m, "ChangeDestConnectionID", arg0) +} + +// ChangeDestConnectionID indicates an expected call of ChangeDestConnectionID +func (mr *MockPackerMockRecorder) ChangeDestConnectionID(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeDestConnectionID", reflect.TypeOf((*MockPacker)(nil).ChangeDestConnectionID), arg0) +} + +// MaybePackAckPacket mocks base method +func (m *MockPacker) MaybePackAckPacket() (*packedPacket, error) { + ret := m.ctrl.Call(m, "MaybePackAckPacket") + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MaybePackAckPacket indicates an expected call of MaybePackAckPacket +func (mr *MockPackerMockRecorder) MaybePackAckPacket() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket)) +} + +// PackConnectionClose mocks base method +func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) { + ret := m.ctrl.Call(m, "PackConnectionClose", arg0) + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackConnectionClose indicates an expected call of PackConnectionClose +func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0) +} + +// PackPacket mocks base method +func (m *MockPacker) PackPacket() (*packedPacket, error) { + ret := m.ctrl.Call(m, "PackPacket") + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackPacket indicates an expected call of PackPacket +func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket)) +} + +// PackRetransmission mocks base method +func (m *MockPacker) PackRetransmission(arg0 *ackhandler.Packet) ([]*packedPacket, error) { + ret := m.ctrl.Call(m, "PackRetransmission", arg0) + ret0, _ := ret[0].([]*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackRetransmission indicates an expected call of PackRetransmission +func (mr *MockPackerMockRecorder) PackRetransmission(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackRetransmission", reflect.TypeOf((*MockPacker)(nil).PackRetransmission), arg0) +} + +// QueueControlFrame mocks base method +func (m *MockPacker) QueueControlFrame(arg0 wire.Frame) { + m.ctrl.Call(m, "QueueControlFrame", arg0) +} + +// QueueControlFrame indicates an expected call of QueueControlFrame +func (mr *MockPackerMockRecorder) QueueControlFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueControlFrame", reflect.TypeOf((*MockPacker)(nil).QueueControlFrame), arg0) +} + +// SetMaxPacketSize mocks base method +func (m *MockPacker) SetMaxPacketSize(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "SetMaxPacketSize", arg0) +} + +// SetMaxPacketSize indicates an expected call of SetMaxPacketSize +func (mr *MockPackerMockRecorder) SetMaxPacketSize(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxPacketSize", reflect.TypeOf((*MockPacker)(nil).SetMaxPacketSize), arg0) +} + +// SetOmitConnectionID mocks base method +func (m *MockPacker) SetOmitConnectionID() { + m.ctrl.Call(m, "SetOmitConnectionID") +} + +// SetOmitConnectionID indicates an expected call of SetOmitConnectionID +func (mr *MockPackerMockRecorder) SetOmitConnectionID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOmitConnectionID", reflect.TypeOf((*MockPacker)(nil).SetOmitConnectionID)) +} diff --git a/mock_stream_frame_source_test.go b/mock_stream_frame_source_test.go index 9b3658052..fa29c02cd 100644 --- a/mock_stream_frame_source_test.go +++ b/mock_stream_frame_source_test.go @@ -35,6 +35,18 @@ func (m *MockStreamFrameSource) EXPECT() *MockStreamFrameSourceMockRecorder { return m.recorder } +// AppendStreamFrames mocks base method +func (m *MockStreamFrameSource) AppendStreamFrames(arg0 []wire.Frame, arg1 protocol.ByteCount) []wire.Frame { + ret := m.ctrl.Call(m, "AppendStreamFrames", arg0, arg1) + ret0, _ := ret[0].([]wire.Frame) + return ret0 +} + +// AppendStreamFrames indicates an expected call of AppendStreamFrames +func (mr *MockStreamFrameSourceMockRecorder) AppendStreamFrames(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AppendStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).AppendStreamFrames), arg0, arg1) +} + // HasCryptoStreamData mocks base method func (m *MockStreamFrameSource) HasCryptoStreamData() bool { ret := m.ctrl.Call(m, "HasCryptoStreamData") @@ -58,15 +70,3 @@ func (m *MockStreamFrameSource) PopCryptoStreamFrame(arg0 protocol.ByteCount) *w func (mr *MockStreamFrameSourceMockRecorder) PopCryptoStreamFrame(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopCryptoStreamFrame", reflect.TypeOf((*MockStreamFrameSource)(nil).PopCryptoStreamFrame), arg0) } - -// PopStreamFrames mocks base method -func (m *MockStreamFrameSource) PopStreamFrames(arg0 protocol.ByteCount) []*wire.StreamFrame { - ret := m.ctrl.Call(m, "PopStreamFrames", arg0) - ret0, _ := ret[0].([]*wire.StreamFrame) - return ret0 -} - -// PopStreamFrames indicates an expected call of PopStreamFrames -func (mr *MockStreamFrameSourceMockRecorder) PopStreamFrames(arg0 interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PopStreamFrames", reflect.TypeOf((*MockStreamFrameSource)(nil).PopStreamFrames), arg0) -} diff --git a/mockgen.go b/mockgen.go index 5b7cd4f0a..3667134d8 100644 --- a/mockgen.go +++ b/mockgen.go @@ -6,9 +6,11 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender" //go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter" //go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/lucas-clemente/quic-go ackFrameSource" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager" //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" +//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" //go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner" diff --git a/packet_packer.go b/packet_packer.go index 9b2cccd50..538bc3e02 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -15,6 +15,19 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type packer interface { + QueueControlFrame(frame wire.Frame) + + PackPacket() (*packedPacket, error) + MaybePackAckPacket() (*packedPacket, error) + PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) + PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) + + SetOmitConnectionID() + ChangeDestConnectionID(protocol.ConnectionID) + SetMaxPacketSize(protocol.ByteCount) +} + type packedPacket struct { header *wire.Header raw []byte @@ -33,6 +46,23 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { } } +func getMaxPacketSize(addr net.Addr) protocol.ByteCount { + maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + // If ip is not an IPv4 address, To4 returns nil. + // Note that there might be some corner cases, where this is not correct. + // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. + if udpAddr.IP.To4() == nil { + maxSize = protocol.MaxPacketSizeIPv6 + } else { + maxSize = protocol.MaxPacketSizeIPv4 + } + } + return maxSize +} + type sealingManager interface { GetSealer() (protocol.EncryptionLevel, handshake.Sealer) GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) @@ -42,7 +72,20 @@ type sealingManager interface { type streamFrameSource interface { HasCryptoStreamData() bool PopCryptoStreamFrame(protocol.ByteCount) *wire.StreamFrame - PopStreamFrames(protocol.ByteCount) []*wire.StreamFrame + AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame +} + +// sentAndReceivedPacketManager is only needed until STOP_WAITING is removed +type sentAndReceivedPacketManager struct { + ackhandler.SentPacketHandler + ackhandler.ReceivedPacketHandler +} + +var _ ackFrameSource = &sentAndReceivedPacketManager{} + +type ackFrameSource interface { + GetAckFrame() *wire.AckFrame + GetStopWaitingFrame(bool) *wire.StopWaitingFrame } type packetPacker struct { @@ -59,18 +102,19 @@ type packetPacker struct { packetNumberGenerator *packetNumberGenerator getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen streams streamFrameSource + acks ackFrameSource controlFrameMutex sync.Mutex controlFrames []wire.Frame - stopWaiting *wire.StopWaitingFrame - ackFrame *wire.AckFrame omitConnectionID bool maxPacketSize protocol.ByteCount hasSentPacket bool // has the packetPacker already sent a packet numNonRetransmittableAcks int } +var _ packer = &packetPacker{} + func newPacketPacker( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, @@ -81,22 +125,10 @@ func newPacketPacker( divNonce []byte, cryptoSetup sealingManager, streamFramer streamFrameSource, + acks ackFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { - maxPacketSize := protocol.ByteCount(protocol.MinInitialPacketSize) - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. - if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - // If ip is not an IPv4 address, To4 returns nil. - // Note that there might be some corner cases, where this is not correct. - // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. - if udpAddr.IP.To4() == nil { - maxPacketSize = protocol.MaxPacketSizeIPv6 - } else { - maxPacketSize = protocol.MaxPacketSizeIPv4 - } - } return &packetPacker{ cryptoSetup: cryptoSetup, divNonce: divNonce, @@ -106,9 +138,10 @@ func newPacketPacker( perspective: perspective, version: version, streams: streamFramer, + acks: acks, getPacketNumberLen: getPacketNumberLen, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), - maxPacketSize: maxPacketSize, + maxPacketSize: getMaxPacketSize(remoteAddr), } } @@ -126,20 +159,22 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac }, err } -func (p *packetPacker) PackAckPacket() (*packedPacket, error) { - if p.ackFrame == nil { - return nil, errors.New("packet packer BUG: no ack frame queued") +func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { + ack := p.acks.GetAckFrame() + if ack == nil { + return nil, nil } encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - frames := []wire.Frame{p.ackFrame} - if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC - p.stopWaiting.PacketNumber = header.PacketNumber - p.stopWaiting.PacketNumberLen = header.PacketNumberLen - frames = append(frames, p.stopWaiting) - p.stopWaiting = nil + frames := []wire.Frame{ack} + // add a STOP_WAITING frame, when using gQUIC + if p.version.UsesStopWaitingFrames() { + if swf := p.acks.GetStopWaitingFrame(false); swf != nil { + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + frames = append(frames, swf) + } } - p.ackFrame = nil raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ header: header, @@ -171,9 +206,14 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP var packets []*packedPacket encLevel, sealer := p.cryptoSetup.GetSealer() + var swf *wire.StopWaitingFrame + // for gQUIC: add a STOP_WAITING for *every* retransmission + if p.version.UsesStopWaitingFrames() { + swf = p.acks.GetStopWaitingFrame(true) + } for len(controlFrames) > 0 || len(streamFrames) > 0 { var frames []wire.Frame - var payloadLength protocol.ByteCount + var length protocol.ByteCount header := p.getHeader(encLevel) headerLength, err := header.GetLength(p.version) @@ -182,28 +222,24 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP } maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength - // for gQUIC: add a STOP_WAITING for *every* retransmission if p.version.UsesStopWaitingFrames() { - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") - } - // create a new StopWaitingFrame, since we might need to send more than one packet as a retransmission - swf := &wire.StopWaitingFrame{ - LeastUnacked: p.stopWaiting.LeastUnacked, + // create a new STOP_WAIITNG Frame, since we might need to send more than one packet as a retransmission + stopWaitingFrame := &wire.StopWaitingFrame{ + LeastUnacked: swf.LeastUnacked, PacketNumber: header.PacketNumber, PacketNumberLen: header.PacketNumberLen, } - payloadLength += swf.Length(p.version) - frames = append(frames, swf) + length += stopWaitingFrame.Length(p.version) + frames = append(frames, stopWaitingFrame) } for len(controlFrames) > 0 { frame := controlFrames[0] - length := frame.Length(p.version) - if payloadLength+length > maxSize { + frameLen := frame.Length(p.version) + if length+frameLen > maxSize { break } - payloadLength += length + length += frameLen frames = append(frames, frame) controlFrames = controlFrames[1:] } @@ -218,12 +254,12 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP } else { maxSize += 2 } - for len(streamFrames) > 0 && payloadLength+protocol.MinStreamFrameSize < maxSize { + for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize { // TODO: optimize by setting DataLenPresent = false on all but the last STREAM frame frame := streamFrames[0] frameToAdd := frame - sf, err := frame.MaybeSplitOffFrame(maxSize-payloadLength, p.version) + sf, err := frame.MaybeSplitOffFrame(maxSize-length, p.version) if err != nil { return nil, err } @@ -232,7 +268,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP } else { streamFrames = streamFrames[1:] } - payloadLength += frameToAdd.Length(p.version) + length += frameToAdd.Length(p.version) frames = append(frames, frameToAdd) } if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { @@ -249,7 +285,6 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP encryptionLevel: encLevel, }) } - p.stopWaiting = nil return packets, nil } @@ -267,13 +302,9 @@ func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (* header.Type = packet.PacketType var frames []wire.Frame if p.version.UsesStopWaitingFrames() { // for gQUIC: pack a STOP_WAITING first - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") - } - swf := p.stopWaiting + swf := p.acks.GetStopWaitingFrame(true) swf.PacketNumber = header.PacketNumber swf.PacketNumberLen = header.PacketNumberLen - p.stopWaiting = nil frames = append([]wire.Frame{swf}, packet.Frames...) } else { frames = packet.Frames @@ -306,49 +337,37 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if err != nil { return nil, err } - if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = header.PacketNumber - p.stopWaiting.PacketNumberLen = header.PacketNumberLen - } maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength - payloadFrames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) + frames, err := p.composeNextPacket(header, maxSize, p.canSendData(encLevel)) if err != nil { return nil, err } // Check if we have enough frames to send - if len(payloadFrames) == 0 { + if len(frames) == 0 { return nil, nil } - // Don't send out packets that only contain a StopWaitingFrame - if len(payloadFrames) == 1 && p.stopWaiting != nil { - return nil, nil - } - if p.ackFrame != nil { - // check if this packet only contains an ACK (and maybe a STOP_WAITING) - if len(payloadFrames) == 1 || (p.stopWaiting != nil && len(payloadFrames) == 2) { - if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { - payloadFrames = append(payloadFrames, &wire.PingFrame{}) - p.numNonRetransmittableAcks = 0 - } else { - p.numNonRetransmittableAcks++ - } - } else { + // check if this packet only contains an ACK (and maybe a STOP_WAITING) + if !ackhandler.HasRetransmittableFrames(frames) { + if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { + frames = append(frames, &wire.PingFrame{}) p.numNonRetransmittableAcks = 0 + } else { + p.numNonRetransmittableAcks++ } + } else { + p.numNonRetransmittableAcks = 0 } - p.stopWaiting = nil - p.ackFrame = nil - raw, err := p.writeAndSealPacket(header, payloadFrames, sealer) + raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { return nil, err } return &packedPacket{ header: header, raw: raw, - frames: payloadFrames, + frames: frames, encryptionLevel: encLevel, }, nil } @@ -377,42 +396,44 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { } func (p *packetPacker) composeNextPacket( + header *wire.Header, // only needed to fill in the STOP_WAITING frame maxFrameSize protocol.ByteCount, canSendStreamFrames bool, ) ([]wire.Frame, error) { - var payloadLength protocol.ByteCount - var payloadFrames []wire.Frame + var length protocol.ByteCount + var frames []wire.Frame // STOP_WAITING and ACK will always fit - if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them - payloadFrames = append(payloadFrames, p.ackFrame) - l := p.ackFrame.Length(p.version) - payloadLength += l - } - if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC - payloadFrames = append(payloadFrames, p.stopWaiting) - payloadLength += p.stopWaiting.Length(p.version) + // ACKs need to go first, so that the sentPacketHandler will recognize them + if ack := p.acks.GetAckFrame(); ack != nil { + frames = append(frames, ack) + length += ack.Length(p.version) + // add a STOP_WAITING, for gQUIC + if p.version.UsesStopWaitingFrames() { + if swf := p.acks.GetStopWaitingFrame(false); swf != nil { + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + frames = append(frames, swf) + length += swf.Length(p.version) + } + } } p.controlFrameMutex.Lock() for len(p.controlFrames) > 0 { frame := p.controlFrames[len(p.controlFrames)-1] - length := frame.Length(p.version) - if payloadLength+length > maxFrameSize { + frameLen := frame.Length(p.version) + if length+frameLen > maxFrameSize { break } - payloadFrames = append(payloadFrames, frame) - payloadLength += length + frames = append(frames, frame) + length += frameLen p.controlFrames = p.controlFrames[:len(p.controlFrames)-1] } p.controlFrameMutex.Unlock() - if payloadLength > maxFrameSize { - return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", payloadLength, maxFrameSize) - } - if !canSendStreamFrames { - return payloadFrames, nil + return frames, nil } // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field @@ -426,28 +447,20 @@ func (p *packetPacker) composeNextPacket( maxFrameSize += 2 } - fs := p.streams.PopStreamFrames(maxFrameSize - payloadLength) - if len(fs) != 0 { - fs[len(fs)-1].DataLenPresent = false + frames = p.streams.AppendStreamFrames(frames, maxFrameSize-length) + if len(frames) > 0 { + lastFrame := frames[len(frames)-1] + if sf, ok := lastFrame.(*wire.StreamFrame); ok { + sf.DataLenPresent = false + } } - - for _, f := range fs { - payloadFrames = append(payloadFrames, f) - } - return payloadFrames, nil + return frames, nil } func (p *packetPacker) QueueControlFrame(frame wire.Frame) { - switch f := frame.(type) { - case *wire.StopWaitingFrame: - p.stopWaiting = f - case *wire.AckFrame: - p.ackFrame = f - default: - p.controlFrameMutex.Lock() - p.controlFrames = append(p.controlFrames, f) - p.controlFrameMutex.Unlock() - } + p.controlFrameMutex.Lock() + p.controlFrames = append(p.controlFrames, frame) + p.controlFrameMutex.Unlock() } func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { @@ -494,7 +507,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header func (p *packetPacker) writeAndSealPacket( header *wire.Header, - payloadFrames []wire.Frame, + frames []wire.Frame, sealer handshake.Sealer, ) ([]byte, error) { raw := *getPacketBuffer() @@ -507,7 +520,7 @@ func (p *packetPacker) writeAndSealPacket( header.PayloadLen = protocol.ByteCount(protocol.MinInitialPacketSize) - headerLen } else { payloadLen := protocol.ByteCount(sealer.Overhead()) - for _, frame := range payloadFrames { + for _, frame := range frames { payloadLen += frame.Length(p.version) } header.PayloadLen = payloadLen @@ -521,12 +534,12 @@ func (p *packetPacker) writeAndSealPacket( // the Initial packet needs to be padded, so the last STREAM frame must have the data length present if header.Type == protocol.PacketTypeInitial { - lastFrame := payloadFrames[len(payloadFrames)-1] + lastFrame := frames[len(frames)-1] if sf, ok := lastFrame.(*wire.StreamFrame); ok { sf.DataLenPresent = true } } - for _, frame := range payloadFrames { + for _, frame := range frames { if err := frame.Write(buffer, p.version); err != nil { return nil, err } diff --git a/packet_packer_test.go b/packet_packer_test.go index ae40aa2e2..df7917750 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -60,6 +60,7 @@ var _ = Describe("Packet packer", func() { publicHeaderLen protocol.ByteCount maxFrameSize protocol.ByteCount mockStreamFramer *MockStreamFrameSource + mockAckFramer *MockAckFrameSource divNonce []byte token []byte ) @@ -73,11 +74,18 @@ var _ = Describe("Packet packer", func() { ExpectWithOffset(0, hdr.PayloadLen).To(BeEquivalentTo(r.Len())) } + expectAppendStreamFrames := func(frames ...wire.Frame) { + mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { + return append(frames, fs...) + }) + } + BeforeEach(func() { version := versionGQUICFrames mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() mockStreamFramer = NewMockStreamFrameSource(mockCtrl) + mockAckFramer = NewMockAckFrameSource(mockCtrl) divNonce = bytes.Repeat([]byte{'e'}, 32) token = []byte("initial token") @@ -91,6 +99,7 @@ var _ = Describe("Packet packer", func() { divNonce, &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, mockStreamFramer, + mockAckFramer, protocol.PerspectiveServer, version, ) @@ -102,31 +111,26 @@ var _ = Describe("Packet packer", func() { }) Context("determining the maximum packet size", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - It("uses the minimum initial size, if it can't determine if the remote address is IPv4 or IPv6", func() { - remoteAddr := &net.TCPAddr{} - packer = newPacketPacker(connID, connID, 1, nil, remoteAddr, nil, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MinInitialPacketSize)) + Expect(getMaxPacketSize(&net.TCPAddr{})).To(BeEquivalentTo(protocol.MinInitialPacketSize)) }) It("uses the maximum IPv4 packet size, if the remote address is IPv4", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} - packer = newPacketPacker(connID, connID, 1, nil, remoteAddr, nil, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv4)) + addr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} + Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.MaxPacketSizeIPv4)) }) It("uses the maximum IPv6 packet size, if the remote address is IPv6", func() { ip := net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334") - remoteAddr := &net.UDPAddr{IP: ip, Port: 1337} - packer = newPacketPacker(connID, connID, 1, nil, remoteAddr, nil, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv6)) + addr := &net.UDPAddr{IP: ip, Port: 1337} + Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.MaxPacketSizeIPv6)) }) }) It("returns nil when no packet is queued", func() { + mockAckFramer.EXPECT().GetAckFrame() mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + mockStreamFramer.EXPECT().AppendStreamFrames(nil, gomock.Any()) p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) @@ -134,11 +138,12 @@ var _ = Describe("Packet packer", func() { It("packs single packets", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() f := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, } - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f}) + expectAppendStreamFrames(f) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -150,10 +155,11 @@ var _ = Describe("Packet packer", func() { It("stores the encryption level a packet was sealed with", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{ + mockAckFramer.EXPECT().GetAckFrame() + expectAppendStreamFrames(&wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), - }}) + }) packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -356,7 +362,8 @@ var _ = Describe("Packet packer", func() { It("packs only control frames", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + mockAckFramer.EXPECT().GetAckFrame() + expectAppendStreamFrames() packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.MaxDataFrame{}) p, err := packer.PackPacket() @@ -368,7 +375,9 @@ var _ = Describe("Packet packer", func() { It("increases the packet number", func() { mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + mockAckFramer.EXPECT().GetAckFrame().Times(2) + expectAppendStreamFrames() + expectAppendStreamFrames() packer.QueueControlFrame(&wire.RstStreamFrame{}) p1, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) @@ -380,49 +389,33 @@ var _ = Describe("Packet packer", func() { Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber)) }) - It("packs a STOP_WAITING frame first", func() { + It("packs ACKs and STOP_WAITING frames first", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) - packer.packetNumberGenerator.next = 15 + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}} swf := &wire.StopWaitingFrame{LeastUnacked: 10} - packer.QueueControlFrame(&wire.RstStreamFrame{}) - packer.QueueControlFrame(swf) + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) + expectAppendStreamFrames() + packer.packetNumberGenerator.next = 15 + cf := &wire.RstStreamFrame{} + packer.QueueControlFrame(cf) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0]).To(Equal(swf)) + Expect(p.frames).To(Equal([]wire.Frame{ack, swf, cf})) }) It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) - packer.packetNumberGenerator.next = 0x1337 swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337 - 0x100} - packer.QueueControlFrame(&wire.RstStreamFrame{}) - packer.QueueControlFrame(swf) + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) + expectAppendStreamFrames() + packer.packetNumberGenerator.next = 0x1337 p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - }) - - It("does not pack a packet containing only a STOP_WAITING frame", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) - swf := &wire.StopWaitingFrame{LeastUnacked: 10} - packer.QueueControlFrame(swf) - p, err := packer.PackPacket() - Expect(p).To(BeNil()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("packs a packet if it has queued control frames, but no new control frames", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) - packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames[1].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) }) It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { @@ -434,7 +427,8 @@ var _ = Describe("Packet packer", func() { Expect(p).To(BeNil()) }) - It("packs many control frames into 1 packets", func() { + It("packs many control frames into one packets", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) f := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 1, Smallest: 1}}} b := &bytes.Buffer{} err := f.Write(b, packer.version) @@ -445,15 +439,16 @@ var _ = Describe("Packet packer", func() { controlFrames = append(controlFrames, f) } packer.controlFrames = controlFrames - payloadFrames, err := packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(BeEmpty()) }) It("packs a lot of control frames into 2 packets if they don't fit into one", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) blockedFrame := &wire.BlockedFrame{} maxFramesPerPacket := int(maxFrameSize) / int(blockedFrame.Length(packer.version)) var controlFrames []wire.Frame @@ -461,26 +456,27 @@ var _ = Describe("Packet packer", func() { controlFrames = append(controlFrames, blockedFrame) } packer.controlFrames = controlFrames - payloadFrames, err := packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(HaveLen(10)) }) It("only increases the packet number when there is an actual packet to send", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + expectAppendStreamFrames() packer.packetNumberGenerator.nextToSkip = 1000 p, err := packer.PackPacket() Expect(p).To(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(packer.packetNumberGenerator.Peek()).To(Equal(protocol.PacketNumber(1))) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{{ + expectAppendStreamFrames(&wire.StreamFrame{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - }}) + }) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) @@ -491,9 +487,10 @@ var _ = Describe("Packet packer", func() { Context("making ACK packets retransmittable", func() { sendMaxNumNonRetransmittableAcks := func() { mockStreamFramer.EXPECT().HasCryptoStreamData().Times(protocol.MaxNonRetransmittableAcks) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(protocol.MaxNonRetransmittableAcks) for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendStreamFrames() p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) @@ -503,15 +500,19 @@ var _ = Describe("Packet packer", func() { It("adds a PING frame when it's supposed to send a retransmittable packet", func() { sendMaxNumNonRetransmittableAcks() - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendStreamFrames() p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(&wire.PingFrame{})) // make sure the next packet doesn't contain another PING - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendStreamFrames() p, err = packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) @@ -520,12 +521,18 @@ var _ = Describe("Packet packer", func() { It("waits until there's something to send before adding a PING frame", func() { sendMaxNumNonRetransmittableAcks() - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Times(2) + // nothing to send + mockStreamFramer.EXPECT().HasCryptoStreamData() + expectAppendStreamFrames() + mockAckFramer.EXPECT().GetAckFrame() p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + // now add some frame to send + expectAppendStreamFrames() + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(2)) @@ -535,58 +542,65 @@ var _ = Describe("Packet packer", func() { It("doesn't send a PING if it already sent another retransmittable frame", func() { sendMaxNumNonRetransmittableAcks() mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + mockAckFramer.EXPECT().GetAckFrame() + expectAppendStreamFrames() packer.QueueControlFrame(&wire.MaxDataFrame{}) - packer.QueueControlFrame(&wire.StopWaitingFrame{}) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) }) }) Context("STREAM frame handling", func() { - It("does not splits a STREAM frame with maximum size, for gQUIC frames", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame { + It("does not split a STREAM frame with maximum size, for gQUIC frames", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() + mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { f := &wire.StreamFrame{ Offset: 1, StreamID: 5, DataLenPresent: true, } f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.Length(packer.version))) - return []*wire.StreamFrame{f} + return []wire.Frame{f}, f.Length(packer.version) }) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(maxPacketSize))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + // make sure there's nothing else to send + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) - It("does not splits a STREAM frame with maximum size, for IETF draft style frame", func() { + It("does not split a STREAM frame with maximum size, for IETF draft style frame", func() { packer.version = versionIETFFrames - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).DoAndReturn(func(maxSize protocol.ByteCount) []*wire.StreamFrame { + mockAckFramer.EXPECT().GetAckFrame() + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { f := &wire.StreamFrame{ Offset: 1, StreamID: 5, DataLenPresent: true, } f.Data = bytes.Repeat([]byte{'f'}, int(maxSize-f.Length(packer.version))) - return []*wire.StreamFrame{f} + return []wire.Frame{f}, f.Length(packer.version) }) - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) + mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(maxPacketSize))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + // make sure there's nothing else to send + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -609,21 +623,23 @@ var _ = Describe("Packet packer", func() { DataLenPresent: true, } mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f1, f2, f3}) + mockAckFramer.EXPECT().GetAckFrame() + expectAppendStreamFrames(f1, f2, f3) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) - Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) }) It("refuses to send unencrypted stream data on a data stream", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted p, err := packer.PackPacket() @@ -632,12 +648,13 @@ var _ = Describe("Packet packer", func() { }) It("sends non forward-secure data as the client", func() { + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), } - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).Return([]*wire.StreamFrame{f}) + expectAppendStreamFrames(f) packer.perspective = protocol.PerspectiveClient packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure p, err := packer.PackPacket() @@ -648,6 +665,7 @@ var _ = Describe("Packet packer", func() { It("does not send non forward-secure data as the server", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure p, err := packer.PackPacket() @@ -706,10 +724,11 @@ var _ = Describe("Packet packer", func() { It("does not pack STREAM frames if not allowed", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 10, Smallest: 1}}} + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 10, Smallest: 1}}} - packer.QueueControlFrame(ack) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal([]wire.Frame{ack})) @@ -718,37 +737,25 @@ var _ = Describe("Packet packer", func() { It("packs a single ACK", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} - packer.QueueControlFrame(ack) + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendStreamFrames() p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames[0]).To(Equal(ack)) }) - It("does not return nil if we only have a single ACK but request it to be sent", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()) - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - packer.QueueControlFrame(ack) - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) - }) - Context("retransmitting of handshake packets", func() { - swf := &wire.StopWaitingFrame{LeastUnacked: 1} sf := &wire.StreamFrame{ StreamID: 1, Data: []byte("foobar"), } - BeforeEach(func() { - packer.QueueControlFrame(swf) - }) - It("packs a retransmission for a packet sent with no encryption", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 1} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packet := &ackhandler.Packet{ PacketType: protocol.PacketTypeHandshake, EncryptionLevel: protocol.EncryptionUnencrypted, @@ -775,7 +782,9 @@ var _ = Describe("Packet packer", func() { Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) }) - It("packs a retransmission for a packet sent with initial encryption", func() { + It("packs a retransmission for a packet sent with secure encryption", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 1} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packet := &ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, Frames: []wire.Frame{sf}, @@ -785,25 +794,16 @@ var _ = Describe("Packet packer", func() { Expect(p).To(HaveLen(1)) Expect(p[0].frames).To(Equal([]wire.Frame{swf, sf})) Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionSecure)) - // a packet sent by the server with initial encryption contains the SHLO + // a packet sent by the server with secure encryption contains the SHLO // it needs to have a diversification nonce Expect(p[0].raw).To(ContainSubstring(string(divNonce))) }) - It("includes the diversification nonce on packets sent with initial encryption", func() { - packet := &ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionSecure, - Frames: []wire.Frame{sf}, - } - p, err := packer.PackRetransmission(packet) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(HaveLen(1)) - Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionSecure)) - }) - // this should never happen, since non forward-secure packets are limited to a size smaller than MaxPacketSize, such that it is always possible to retransmit them without splitting the StreamFrame // (note that the retransmitted packet needs to have enough space for the StopWaitingFrame) It("refuses to send a packet larger than MaxPacketSize", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 1} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packet := &ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, Frames: []wire.Frame{ @@ -870,23 +870,13 @@ var _ = Describe("Packet packer", func() { Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p[0].header.Token).To(Equal(token)) }) - - It("refuses to retransmit packets without a STOP_WAITING Frame", func() { - packer.stopWaiting = nil - _, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionSecure, - }) - Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")) - }) }) Context("retransmission of forward-secure packets", func() { - BeforeEach(func() { - packer.packetNumberGenerator.next = 15 - packer.stopWaiting = &wire.StopWaitingFrame{LeastUnacked: 7} - }) - It("retransmits a small packet", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 7} + packer.packetNumberGenerator.next = 10 + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) frames := []wire.Frame{ &wire.MaxDataFrame{ByteOffset: 0x1234}, &wire.StreamFrame{StreamID: 42, Data: []byte("foobar")}, @@ -907,15 +897,6 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[1:]).To(Equal(frames)) }) - It("refuses to retransmit packets without a STOP_WAITING Frame", func() { - packer.stopWaiting = nil - _, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: []wire.Frame{&wire.MaxDataFrame{ByteOffset: 0x1234}}, - }) - Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")) - }) - It("packs two packets for retransmission if the original packet contained many control frames", func() { var frames []wire.Frame var totalLen protocol.ByteCount @@ -925,6 +906,9 @@ var _ = Describe("Packet packer", func() { frames = append(frames, f) totalLen += f.Length(packer.version) } + packer.packetNumberGenerator.next = 10 + swf := &wire.StopWaitingFrame{LeastUnacked: 7} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packets, err := packer.PackRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionForwardSecure, Frames: frames, @@ -942,6 +926,8 @@ var _ = Describe("Packet packer", func() { }) It("splits a STREAM frame that doesn't fit", func() { + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packets, err := packer.PackRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionForwardSecure, Frames: []wire.Frame{&wire.StreamFrame{ @@ -969,6 +955,8 @@ var _ = Describe("Packet packer", func() { }) It("packs two packets for retransmission if the original packet contained many STREAM frames", func() { + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) var frames []wire.Frame var totalLen protocol.ByteCount // pack a bunch of control frames, such that the packet is way bigger than a single packet @@ -998,6 +986,8 @@ var _ = Describe("Packet packer", func() { }) It("correctly sets the DataLenPresent on STREAM frames", func() { + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) frames := []wire.Frame{ &wire.StreamFrame{StreamID: 4, Data: []byte("foobar"), DataLenPresent: true}, &wire.StreamFrame{StreamID: 5, Data: []byte("barfoo")}, @@ -1022,34 +1012,43 @@ var _ = Describe("Packet packer", func() { }) Context("packing ACK packets", func() { - It("packs ACK packets", func() { - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}) - p, err := packer.PackAckPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - ack := p.frames[0].(*wire.AckFrame) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10))) + It("doesn't pack a packet if there's no ACK to send", func() { + mockAckFramer.EXPECT().GetAckFrame() + p, err := packer.MaybePackAckPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) }) - It("packs ACK packets with STOP_WAITING frames", func() { - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}) - packer.QueueControlFrame(&wire.StopWaitingFrame{}) - p, err := packer.PackAckPacket() + It("packs ACK packets", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) + p, err := packer.MaybePackAckPacket() Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.frames[1]).To(Equal(&wire.StopWaitingFrame{PacketNumber: 1, PacketNumberLen: 2})) + Expect(p.frames).To(Equal([]wire.Frame{ack, swf})) + }) + + It("doesn't add a STOP_WAITING frame for IETF QUIC", func() { + packer.version = versionIETFFrames + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + p, err := packer.MaybePackAckPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{ack})) }) }) Context("max packet size", func() { It("sets the maximum packet size", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) for i := 0; i < 10*int(maxPacketSize); i++ { packer.QueueControlFrame(&wire.PingFrame{}) } mockStreamFramer.EXPECT().HasCryptoStreamData().AnyTimes() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).AnyTimes() + mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { + return fs + }).AnyTimes() p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.raw).To(HaveLen(int(maxPacketSize))) @@ -1061,11 +1060,14 @@ var _ = Describe("Packet packer", func() { }) It("doesn't increase the max packet size", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) for i := 0; i < 10*int(maxPacketSize); i++ { packer.QueueControlFrame(&wire.PingFrame{}) } mockStreamFramer.EXPECT().HasCryptoStreamData().AnyTimes() - mockStreamFramer.EXPECT().PopStreamFrames(gomock.Any()).AnyTimes() + mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(fs []wire.Frame, _ protocol.ByteCount) []wire.Frame { + return fs + }).AnyTimes() p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.raw).To(HaveLen(int(maxPacketSize))) diff --git a/session.go b/session.go index 478c1d280..b6087cf25 100644 --- a/session.go +++ b/session.go @@ -97,7 +97,7 @@ type session struct { connFlowController flowcontrol.ConnectionFlowController unpacker unpacker - packer *packetPacker + packer packer cryptoStreamHandler cryptoStreamHandler @@ -216,6 +216,7 @@ func newSession( divNonce, cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -289,6 +290,7 @@ var newClientSession = func( nil, // no diversification nonce cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -344,6 +346,7 @@ func newTLSServerSession( nil, // no diversification nonce cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -408,6 +411,7 @@ var newTLSClientSession = func( nil, // no diversification nonce cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -417,6 +421,7 @@ var newTLSClientSession = func( func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), @@ -439,7 +444,6 @@ func (s *session) postSetup() error { s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.packer.QueueControlFrame) return nil } @@ -988,21 +992,13 @@ sendLoop: } func (s *session) maybeSendAckOnlyPacket() error { - ack := s.receivedPacketHandler.GetAckFrame() - if ack == nil { - return nil - } - s.packer.QueueControlFrame(ack) - - if s.version.UsesStopWaitingFrames() { // for gQUIC, maybe add a STOP_WAITING - if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { - s.packer.QueueControlFrame(swf) - } - } - packet, err := s.packer.PackAckPacket() + packet, err := s.packer.MaybePackAckPacket() if err != nil { return err } + if packet == nil { + return nil + } s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) return s.sendPackedPacket(packet) } @@ -1033,9 +1029,6 @@ func (s *session) maybeSendRetransmission() (bool, error) { s.logger.Debugf("Dequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) } - if s.version.UsesStopWaitingFrames() { - s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - } packets, err := s.packer.PackRetransmission(retransmitPacket) if err != nil { return false, err @@ -1060,9 +1053,6 @@ func (s *session) sendProbePacket() error { } s.logger.Debugf("Sending a retransmission for %#x as a probe packet.", p.PacketNumber) - if s.version.UsesStopWaitingFrames() { - s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - } packets, err := s.packer.PackRetransmission(p) if err != nil { return err @@ -1086,15 +1076,6 @@ func (s *session) sendPacket() (bool, error) { } s.windowUpdateQueue.QueueAll() - if ack := s.receivedPacketHandler.GetAckFrame(); ack != nil { - s.packer.QueueControlFrame(ack) - if s.version.UsesStopWaitingFrames() { - if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { - s.packer.QueueControlFrame(swf) - } - } - } - packet, err := s.packer.PackPacket() if err != nil || packet == nil { return false, err diff --git a/session_test.go b/session_test.go index 26cc46507..fa6fb390c 100644 --- a/session_test.go +++ b/session_test.go @@ -73,6 +73,7 @@ var _ = Describe("Session", func() { mconn *mockConnection cryptoSetup *mockCryptoSetup streamManager *MockStreamManager + packer *MockPacker handshakeChan chan<- struct{} ) @@ -121,6 +122,8 @@ var _ = Describe("Session", func() { sess = pSess.(*session) streamManager = NewMockStreamManager(mockCtrl) sess.streamsMap = streamManager + packer = NewMockPacker(mockCtrl) + sess.packer = packer }) AfterEach(func() { @@ -442,11 +445,10 @@ var _ = Describe("Session", func() { }) It("handles PATH_CHALLENGE frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}}, protocol.EncryptionUnspecified) + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + packer.EXPECT().QueueControlFrame(&wire.PathResponseFrame{Data: data}) + err := sess.handleFrames([]wire.Frame{&wire.PathChallengeFrame{Data: data}}, protocol.EncryptionUnspecified) Expect(err).ToNot(HaveOccurred()) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - Expect(sess.packer.controlFrames[0]).To(BeAssignableToTypeOf(&wire.PathResponseFrame{})) - Expect(sess.packer.controlFrames[0].(*wire.PathResponseFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) }) It("handles BLOCKED frames", func() { @@ -515,21 +517,20 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{raw: []byte("connection close")}, nil) + Expect(sess.Close()).To(Succeed()) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) - buf := &bytes.Buffer{} - err := (&wire.ConnectionCloseFrame{ErrorCode: qerr.PeerGoingAway}).Write(buf, sess.version) - Expect(err).ToNot(HaveOccurred()) - Expect(mconn.written).To(Receive(ContainSubstring(buf.String()))) + Expect(mconn.written).To(Receive(ContainSubstring("connection close"))) Expect(sess.Context().Done()).To(BeClosed()) }) It("only closes once", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - sess.Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + Expect(sess.Close()).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) Expect(sess.Context().Done()).To(BeClosed()) @@ -539,6 +540,7 @@ var _ = Describe("Session", func() { testErr := errors.New("test error") streamManager.EXPECT().CloseWithError(qerr.Error(0x1337, testErr.Error())) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.CloseWithError(0x1337, testErr) Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) @@ -564,6 +566,7 @@ var _ = Describe("Session", func() { It("cancels the context when the run loop exists", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) returned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -625,6 +628,7 @@ var _ = Describe("Session", func() { testErr := errors.New("unpack error") unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) hdr.PacketNumber = 5 done := make(chan struct{}) go func() { @@ -702,53 +706,55 @@ var _ = Describe("Session", func() { }) Context("sending packets", func() { - BeforeEach(func() { - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - }) + getPacket := func(pn protocol.PacketNumber) *packedPacket { + data := *getPacketBuffer() + data = append(data, []byte("foobar")...) + return &packedPacket{ + raw: data, + header: &wire.Header{PacketNumber: pn}, + } + } - It("sends ACK frames", func() { - packetNumber := protocol.PacketNumber(0x035e) - err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, time.Now(), true) + It("sends packets", func() { + packer.EXPECT().PackPacket().Return(getPacket(1), nil) + err := sess.receivedPacketHandler.ReceivedPacket(0x035e, time.Now(), true) Expect(err).ToNot(HaveOccurred()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) - It("adds MAX_STREAM_DATA frames", func() { - sess.windowUpdateQueue.callback(&wire.MaxStreamDataFrame{ - StreamID: 2, - ByteOffset: 20, - }) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) - }) - sess.sentPacketHandler = sph + It("doesn't send packets if there's nothing to send", func() { + packer.EXPECT().PackPacket().Return(getPacket(2), nil) + err := sess.receivedPacketHandler.ReceivedPacket(0x035e, time.Now(), true) + Expect(err).ToNot(HaveOccurred()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) }) + It("sends ACK only packets", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAck) + sph.EXPECT().ShouldSendNumPackets().Return(1000) + packer.EXPECT().MaybePackAckPacket() + sess.sentPacketHandler = sph + Expect(sess.sendPackets()).To(Succeed()) + }) + It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) + packer.EXPECT().PackPacket().Return(getPacket(1), nil) sess.connFlowController = fc - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(Equal([]wire.Frame{ - &wire.BlockedFrame{Offset: 1337}, - })) - }) - sess.sentPacketHandler = sph + packer.EXPECT().QueueControlFrame(&wire.BlockedFrame{Offset: 1337}) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) }) - It("sends public reset", func() { + It("sends PUBLIC_RESET", func() { err := sess.sendPublicReset(1) Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) @@ -782,57 +788,80 @@ var _ = Describe("Session", func() { }) It("sends a retransmission and a regular packet in the same run", func() { + packetToRetransmit := &ackhandler.Packet{ + PacketNumber: 10, + PacketType: protocol.PacketTypeHandshake, + } + retransmittedPacket := getPacket(123) + newPacket := getPacket(234) sess.windowUpdateQueue.callback(&wire.MaxDataFrame{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 10, - PacketType: protocol.PacketTypeHandshake, - }) + sph.EXPECT().DequeuePacketForRetransmission().Return(packetToRetransmit) sph.EXPECT().SendMode().Return(ackhandler.SendRetransmission) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().ShouldSendNumPackets().Return(2) sph.EXPECT().TimeUntilSend() - sph.EXPECT().GetStopWaitingFrame(gomock.Any()).Return(&wire.StopWaitingFrame{}) gomock.InOrder( + packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil), sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(10)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { Expect(packets).To(HaveLen(1)) - Expect(len(packets[0].Frames)).To(BeNumerically(">", 0)) - Expect(packets[0].Frames[0]).To(BeAssignableToTypeOf(&wire.StopWaitingFrame{})) - Expect(packets[0].SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(123))) }), + packer.EXPECT().PackPacket().Return(newPacket, nil), sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(1)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.MaxDataFrame{})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(234))) }), ) sess.sentPacketHandler = sph - err := sess.sendPackets() - Expect(err).ToNot(HaveOccurred()) + Expect(sess.sendPackets()).To(Succeed()) + }) + + It("sends multiple packets, if the retransmission is split", func() { + sess.version = versionIETFFrames + packet := &ackhandler.Packet{ + PacketNumber: 42, + Frames: []wire.Frame{&wire.StreamFrame{ + StreamID: 0x5, + Data: []byte("foobar"), + }}, + EncryptionLevel: protocol.EncryptionForwardSecure, + } + retransmissions := []*packedPacket{getPacket(1337), getPacket(1338)} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().DequeuePacketForRetransmission().Return(packet) + packer.EXPECT().PackRetransmission(packet).Return(retransmissions, nil) + sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { + Expect(packets).To(HaveLen(2)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(1337))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(1338))) + }) + sess.sentPacketHandler = sph + sent, err := sess.maybeSendRetransmission() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) + Expect(mconn.written).To(HaveLen(2)) }) It("sends a probe packet", func() { - f := &wire.MaxDataFrame{ByteOffset: 1337} + packetToRetransmit := &ackhandler.Packet{ + PacketNumber: 0x42, + PacketType: protocol.PacketTypeHandshake, + } + retransmittedPacket := getPacket(123) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend() sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendTLP) sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().DequeueProbePacket().Return(&ackhandler.Packet{ - PacketNumber: 0x42, - Frames: []wire.Frame{f}, - }, nil) - sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) + sph.EXPECT().DequeueProbePacket().Return(packetToRetransmit, nil) + packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil) sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(0x42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[1]).To(Equal(f)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph - err := sess.sendPackets() - Expect(err).ToNot(HaveOccurred()) + Expect(sess.sendPackets()).To(Succeed()) }) It("doesn't send when the SentPacketHandler doesn't allow it", func() { @@ -842,420 +871,197 @@ var _ = Describe("Session", func() { err := sess.sendPackets() Expect(err).ToNot(HaveOccurred()) }) - }) - Context("packet pacing", func() { - var sph *mockackhandler.MockSentPacketHandler + Context("packet pacing", func() { + var sph *mockackhandler.MockSentPacketHandler - BeforeEach(func() { - sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() - sess.sentPacketHandler = sph - sess.packer.hasSentPacket = true - streamManager.EXPECT().CloseWithError(gomock.Any()) - }) - - It("sends multiple packets one by one immediately", func() { - sph.EXPECT().SentPacket(gomock.Any()).Times(2) - sph.EXPECT().ShouldSendNumPackets().Return(1).Times(2) - sph.EXPECT().TimeUntilSend().Return(time.Now()).Times(2) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Do(func() { - // make sure there's something to send - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - }).Times(2) // allow 2 packets... - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(2)) - Consistently(mconn.written).Should(HaveLen(2)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck - // we shouldn't send the ACK in the same run - It("doesn't send an ACK right after becoming congestion limited", func() { - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().ShouldSendNumPackets().Return(1000) - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).Times(2) - rph.EXPECT().GetAckFrame() - sess.receivedPacketHandler = rph - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - Consistently(mconn.written).Should(HaveLen(1)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("paces packets", func() { - pacingDelay := scaleDuration(100 * time.Millisecond) - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - sph.EXPECT().SentPacket(gomock.Any()).Times(2) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(-time.Minute)) // send one packet immediately - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)) // send one - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().ShouldSendNumPackets().Times(2).Return(1) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Do(func() { // after sending the first packet - // make sure there's something to send - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 2}) - }).AnyTimes() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1)) - Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("sends multiple packets at once", func() { - sph.EXPECT().SentPacket(gomock.Any()).Times(3) - sph.EXPECT().ShouldSendNumPackets().Return(3) - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Do(func() { - // make sure there's something to send - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - }).Times(3) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(3)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("doesn't set a pacing timer when there is no data to send", func() { - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() // no packet will get sent - Consistently(mconn.written).ShouldNot(Receive()) - // queue a frame, and expect that it won't be sent - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - Consistently(mconn.written).ShouldNot(Receive()) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - }) - - Context("sending ACK only packets", func() { - It("doesn't do anything if there's no ACK to be sent", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph - err := sess.maybeSendAckOnlyPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(mconn.written).To(BeEmpty()) - }) - - It("sends ACK only packets", func() { - swf := &wire.StopWaitingFrame{LeastUnacked: 10} - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - sph.EXPECT().ShouldSendNumPackets().Return(1000) - sph.EXPECT().GetStopWaitingFrame(false).Return(swf) - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.Frames[1]).To(Equal(swf)) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sess.sentPacketHandler = sph - sess.packer.packetNumberGenerator.next = 0x1338 - sess.receivedPacketHandler.ReceivedPacket(1, time.Now(), true) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - // make sure that the go routine returns - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("doesn't include a STOP_WAITING for an ACK-only packet for IETF QUIC", func() { - sess.version = versionIETFFrames - done := make(chan struct{}) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - sph.EXPECT().ShouldSendNumPackets().Return(1000) - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(1)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sess.sentPacketHandler = sph - sess.packer.packetNumberGenerator.next = 0x1338 - sess.receivedPacketHandler.ReceivedPacket(1, time.Now(), true) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - // make sure that the go routine returns - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - }) - - Context("retransmissions", func() { - var sph *mockackhandler.MockSentPacketHandler - BeforeEach(func() { - // a STOP_WAITING frame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet - sess.packer.packetNumberGenerator.next = 0x1337 + 10 - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sess.sentPacketHandler = sph - sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} - }) - - Context("for handshake packets", func() { - It("retransmits an unencrypted packet, and adds a STOP_WAITING frame (for gQUIC)", func() { - sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337} - sph.EXPECT().GetStopWaitingFrame(true).Return(swf) - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 42, - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionUnencrypted, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(p.Frames).To(Equal([]wire.Frame{swf, sf})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + BeforeEach(func() { + sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() + sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() + sess.sentPacketHandler = sph + streamManager.EXPECT().CloseWithError(gomock.Any()) }) - It("retransmits an unencrypted packet, and doesn't add a STOP_WAITING frame (for IETF QUIC)", func() { - sess.version = versionIETFFrames - sess.packer.version = versionIETFFrames - sess.packer.srcConnID = sess.destConnID - sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 1337, - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionUnencrypted, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(1337)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(p.Frames).To(Equal([]wire.Frame{sf})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + It("sends multiple packets one by one immediately", func() { + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().ShouldSendNumPackets().Return(1).Times(2) + sph.EXPECT().TimeUntilSend().Return(time.Now()).Times(2) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) // allow 2 packets... + packer.EXPECT().PackPacket().Return(getPacket(10), nil) + packer.EXPECT().PackPacket().Return(getPacket(11), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(2)) + Consistently(mconn.written).Should(HaveLen(2)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck + // we shouldn't send the ACK in the same run + It("doesn't send an ACK right after becoming congestion limited", func() { + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().ShouldSendNumPackets().Return(1000) + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + sph.EXPECT().SendMode().Return(ackhandler.SendAck) + packer.EXPECT().PackPacket().Return(getPacket(100), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(1)) + Consistently(mconn.written).Should(HaveLen(1)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + It("paces packets", func() { + pacingDelay := scaleDuration(100 * time.Millisecond) + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(-time.Minute)) // send one packet immediately + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)) // send one + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().ShouldSendNumPackets().Times(2).Return(1) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + packer.EXPECT().PackPacket().Return(getPacket(100), nil) + packer.EXPECT().PackPacket().Return(getPacket(101), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(1)) + Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1)) + Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + It("sends multiple packets at once", func() { + sph.EXPECT().SentPacket(gomock.Any()).Times(3) + sph.EXPECT().ShouldSendNumPackets().Return(3) + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) + packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackPacket().Return(getPacket(1001), nil) + packer.EXPECT().PackPacket().Return(getPacket(1002), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(3)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + It("doesn't set a pacing timer when there is no data to send", func() { + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + packer.EXPECT().PackPacket() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() // no packet will get sent + Consistently(mconn.written).ShouldNot(Receive()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sess.Close() + Eventually(done).Should(BeClosed()) }) }) - Context("for packets after the handshake", func() { - It("sends a STREAM frame from a packet queued for retransmission, and adds a STOP_WAITING (for gQUIC)", func() { - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: []byte("foobar"), - } - swf := &wire.StopWaitingFrame{LeastUnacked: 10} - sph.EXPECT().GetStopWaitingFrame(true).Return(swf) - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 0x1337, - Frames: []wire.Frame{f}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(0x1337)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.StopWaitingFrame{})) - Expect(p.Frames[1]).To(Equal(f)) - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + Context("scheduling sending", func() { + It("sends when scheduleSending is called", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1) + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph + packer.EXPECT().PackPacket().Return(getPacket(1), nil) + + go func() { + defer GinkgoRecover() + sess.run() + }() + Consistently(mconn.written).ShouldNot(Receive()) + sess.scheduleSending() + Eventually(mconn.written).Should(Receive()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sess.Close() + Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sends a STREAM frame from a packet queued for retransmission, and doesn't add a STOP_WAITING (for IETF QUIC)", func() { - sess.version = versionIETFFrames - sess.packer.version = versionIETFFrames - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: []byte("foobar"), - } - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 42, - Frames: []wire.Frame{f}, - EncryptionLevel: protocol.EncryptionForwardSecure, + It("sets the timer to the ack timer", func() { + packer.EXPECT().PackPacket().Return(getPacket(1234), nil) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1234))) }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.Frames).To(Equal([]wire.Frame{f})) - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + sess.sentPacketHandler = sph + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) + // make the run loop wait + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) + sess.receivedPacketHandler = rph + + go func() { + defer GinkgoRecover() + sess.run() + }() + Eventually(mconn.written).Should(Receive()) + // make sure the go routine returns + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close() + Eventually(sess.Context().Done()).Should(BeClosed()) }) - - It("sends multiple packets, if the retransmission is split", func() { - sess.version = versionIETFFrames - sess.packer.version = versionIETFFrames - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: bytes.Repeat([]byte{'b'}, int(protocol.MaxPacketSizeIPv4)*3/2), - } - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 42, - Frames: []wire.Frame{f}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(2)) - for _, p := range packets { - Expect(p.Frames).To(HaveLen(1)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - } - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(2)) - }) - }) - }) - - Context("scheduling sending", func() { - BeforeEach(func() { - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} - }) - - It("sends when scheduleSending is called", func() { - sess.packer.packetNumberGenerator.next = 10000 - sess.packer.QueueControlFrame(&wire.BlockedFrame{}) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()) - sess.sentPacketHandler = sph - - go func() { - defer GinkgoRecover() - sess.run() - }() - Consistently(mconn.written).ShouldNot(Receive()) - sess.scheduleSending() - Eventually(mconn.written).Should(Receive()) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(sess.Context().Done()).Should(BeClosed()) - }) - - It("sets the timer to the ack timer", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().GetStopWaitingFrame(gomock.Any()) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.Frames[0].(*wire.AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x1337))) - }) - sess.sentPacketHandler = sph - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 0x1337}}}) - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) - // make the run loop wait - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) - sess.receivedPacketHandler = rph - - go func() { - defer GinkgoRecover() - sess.run() - }() - Eventually(mconn.written).Should(Receive()) - // make sure the go routine returns - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(sess.Context().Done()).Should(BeClosed()) }) }) @@ -1263,6 +1069,7 @@ var _ = Describe("Session", func() { testErr := errors.New("crypto setup error") streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.handleErr = testErr go func() { defer GinkgoRecover() @@ -1272,7 +1079,7 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() { + Context("sending a PUBLIC_RESET when receiving undecryptable packets during the handshake", func() { // sends protocol.MaxUndecryptablePackets+1 undecrytable packets // this completely fills up the undecryptable packets queue and triggers the public reset timer sendUndecryptablePackets := func() { @@ -1294,9 +1101,10 @@ var _ = Describe("Session", func() { sess.unpacker = unpacker sess.cryptoStreamHandler = &mockCryptoSetup{} streamManager.EXPECT().CloseWithError(gomock.Any()).MaxTimes(1) + packer.EXPECT().PackPacket().AnyTimes() }) - It("doesn't immediately send a Public Reset after receiving too many undecryptable packets", func() { + It("doesn't immediately send a PUBLIC_RESET after receiving too many undecryptable packets", func() { go func() { defer GinkgoRecover() sess.run() @@ -1306,11 +1114,12 @@ var _ = Describe("Session", func() { Consistently(mconn.written).Should(HaveLen(0)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sets a deadline to send a Public Reset after receiving too many undecryptable packets", func() { + It("sets a deadline to send a PUBLIC_RESET after receiving too many undecryptable packets", func() { go func() { defer GinkgoRecover() sess.run() @@ -1319,6 +1128,7 @@ var _ = Describe("Session", func() { Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1334,11 +1144,12 @@ var _ = Describe("Session", func() { Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sends a Public Reset after a timeout", func() { + It("sends a PUBLIC_RESET after a timeout", func() { sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero()) go func() { @@ -1356,7 +1167,7 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("doesn't send a Public Reset if decrypting them succeeded during the timeout", func() { + It("doesn't send a PUBLIC_RESET if decrypting them succeeded during the timeout", func() { go func() { defer GinkgoRecover() sess.run() @@ -1369,6 +1180,7 @@ var _ = Describe("Session", func() { Expect(sess.Context().Done()).ToNot(Receive()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1383,6 +1195,7 @@ var _ = Describe("Session", func() { Consistently(sess.undecryptablePackets).Should(BeEmpty()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1408,21 +1221,23 @@ var _ = Describe("Session", func() { // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) It("calls the onHandshakeComplete callback when the handshake completes", func() { + close(handshakeChan) + sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) go func() { defer GinkgoRecover() sess.run() }() - sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) - close(handshakeChan) Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1437,6 +1252,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1452,6 +1268,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.CloseWithError(0x1337, testErr)).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1472,13 +1289,14 @@ var _ = Describe("Session", func() { MaxPacketSize: 0x42, } streamManager.EXPECT().UpdateLimits(¶ms) + packer.EXPECT().SetOmitConnectionID() + packer.EXPECT().SetMaxPacketSize(protocol.ByteCount(0x42)) paramsChan <- params Eventually(func() *handshake.TransportParameters { return sess.peerParams }).Should(Equal(¶ms)) - Eventually(func() bool { return sess.packer.omitConnectionID }).Should(BeTrue()) - Eventually(func() protocol.ByteCount { return sess.packer.maxPacketSize }).Should(Equal(protocol.ByteCount(0x42))) // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1496,20 +1314,23 @@ var _ = Describe("Session", func() { sess.handshakeComplete = true sess.config.KeepAlive = true sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends + sent := make(chan struct{}) + packer.EXPECT().QueueControlFrame(&wire.PingFrame{}) + packer.EXPECT().PackPacket().Do(func() (*packedPacket, error) { + close(sent) + return nil, nil + }) done := make(chan struct{}) go func() { defer GinkgoRecover() sess.run() close(done) }() - var data []byte - Eventually(mconn.written).Should(Receive(&data)) - // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). - Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07})) + Eventually(sent).Should(BeClosed()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1528,6 +1349,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1546,6 +1368,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1561,6 +1384,10 @@ var _ = Describe("Session", func() { sess.handshakeComplete = true sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) done := make(chan struct{}) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + return &packedPacket{}, nil + }) go func() { defer GinkgoRecover() err := sess.run() @@ -1568,12 +1395,15 @@ var _ = Describe("Session", func() { close(done) }() Eventually(done).Should(BeClosed()) - Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) }) It("times out due to non-completed handshake", func() { sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.HandshakeTimeout)) + return &packedPacket{}, nil + }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -1582,13 +1412,16 @@ var _ = Describe("Session", func() { close(done) }() Eventually(done).Should(BeClosed()) - Expect(mconn.written).To(Receive(ContainSubstring("Crypto handshake did not complete in time."))) }) It("does not use the idle timeout before the handshake complete", func() { sess.config.IdleTimeout = 9999 * time.Second defer sess.Close() sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.PeerGoingAway)) + return &packedPacket{}, nil + }) // the handshake timeout is irrelevant here, since it depends on the time the session was created, // and not on the last network activity go func() { @@ -1605,6 +1438,10 @@ var _ = Describe("Session", func() { It("closes the session due to the idle timeout after handshake", func() { sessionRunner.EXPECT().onHandshakeComplete(sess) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + return &packedPacket{}, nil + }) sess.config.IdleTimeout = 0 close(handshakeChan) done := make(chan struct{}) @@ -1615,7 +1452,6 @@ var _ = Describe("Session", func() { close(done) }() Eventually(done).Should(BeClosed()) - Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) }) }) @@ -1718,6 +1554,7 @@ var _ = Describe("Client Session", func() { var ( sess *session sessionRunner *MockSessionRunner + packer *MockPacker mconn *mockConnection handshakeChan chan<- struct{} @@ -1762,6 +1599,8 @@ var _ = Describe("Client Session", func() { ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) + packer = NewMockPacker(mockCtrl) + sess.packer = packer }) AfterEach(func() { @@ -1769,24 +1608,31 @@ var _ = Describe("Client Session", func() { }) It("sends a forward-secure packet when the handshake completes", func() { - sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) - sess.packer.hasSentPacket = true + done := make(chan struct{}) + gomock.InOrder( + sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()), + packer.EXPECT().QueueControlFrame(&wire.PingFrame{}), + packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + close(done) + return &packedPacket{header: &wire.Header{}, raw: *getPacketBuffer()}, nil + }), + packer.EXPECT().PackPacket().AnyTimes(), + ) + close(handshakeChan) go func() { defer GinkgoRecover() sess.run() }() - close(handshakeChan) - Eventually(mconn.written).Should(Receive()) + Eventually(done).Should(BeClosed()) //make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) It("changes the connection ID when receiving the first packet from the server", func() { sess.version = protocol.VersionTLS - sess.packer.version = protocol.VersionTLS - sess.packer.srcConnID = sess.destConnID unpacker := NewMockUnpacker(mockCtrl) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) sess.unpacker = unpacker @@ -1794,25 +1640,20 @@ var _ = Describe("Client Session", func() { defer GinkgoRecover() sess.run() }() + newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} + packer.EXPECT().ChangeDestConnectionID(newConnID) err := sess.handlePacketImpl(&receivedPacket{ header: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - SrcConnectionID: protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}, + SrcConnectionID: newConnID, DestConnectionID: sess.srcConnID, }, data: []byte{0}, }) Expect(err).ToNot(HaveOccurred()) - // the session should have changed the dest connection ID now - sess.packer.hasSentPacket = true - sess.queueControlFrame(&wire.PingFrame{}) - var packet []byte - Eventually(mconn.written).Should(Receive(&packet)) - hdr, err := wire.ParseInvariantHeader(bytes.NewReader(packet), 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7})) // make sure the go routine returns + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1841,6 +1682,7 @@ var _ = Describe("Client Session", func() { Expect(err).ToNot(HaveOccurred()) Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce)) // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) diff --git a/stream_framer.go b/stream_framer.go index aabfac9fe..24247a842 100644 --- a/stream_framer.go +++ b/stream_framer.go @@ -61,14 +61,13 @@ func (f *streamFramer) PopCryptoStreamFrame(maxLen protocol.ByteCount) *wire.Str return frame } -func (f *streamFramer) PopStreamFrames(maxTotalLen protocol.ByteCount) []*wire.StreamFrame { - var currentLen protocol.ByteCount - var frames []*wire.StreamFrame +func (f *streamFramer) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame { + var length protocol.ByteCount f.streamQueueMutex.Lock() // pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet numActiveStreams := len(f.streamQueue) for i := 0; i < numActiveStreams; i++ { - if maxTotalLen-currentLen < protocol.MinStreamFrameSize { + if maxLen-length < protocol.MinStreamFrameSize { break } id := f.streamQueue[0] @@ -81,7 +80,7 @@ func (f *streamFramer) PopStreamFrames(maxTotalLen protocol.ByteCount) []*wire.S delete(f.activeStreams, id) continue } - frame, hasMoreData := str.popStreamFrame(maxTotalLen - currentLen) + frame, hasMoreData := str.popStreamFrame(maxLen - length) if hasMoreData { // put the stream back in the queue (at the end) f.streamQueue = append(f.streamQueue, id) } else { // no more data to send. Stream is not active any more @@ -91,7 +90,7 @@ func (f *streamFramer) PopStreamFrames(maxTotalLen protocol.ByteCount) []*wire.S continue } frames = append(frames, frame) - currentLen += frame.Length(f.version) + length += frame.Length(f.version) } f.streamQueueMutex.Unlock() return frames diff --git a/stream_framer_test.go b/stream_framer_test.go index fff5e2231..edca4d410 100644 --- a/stream_framer_test.go +++ b/stream_framer_test.go @@ -68,7 +68,7 @@ var _ = Describe("Stream Framer", func() { Context("Popping", func() { It("returns nil when popping an empty framer", func() { - Expect(framer.PopStreamFrames(1000)).To(BeEmpty()) + Expect(framer.AppendStreamFrames(nil, 1000)).To(BeEmpty()) }) It("returns STREAM frames", func() { @@ -80,8 +80,22 @@ var _ = Describe("Stream Framer", func() { } stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) - fs := framer.PopStreamFrames(1000) - Expect(fs).To(Equal([]*wire.StreamFrame{f})) + fs := framer.AppendStreamFrames(nil, 1000) + Expect(fs).To(Equal([]wire.Frame{f})) + }) + + It("appends to a frame slice", func() { + streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) + f := &wire.StreamFrame{ + StreamID: id1, + Data: []byte("foobar"), + } + stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) + framer.AddActiveStream(id1) + mdf := &wire.MaxDataFrame{ByteOffset: 1337} + frames := []wire.Frame{mdf} + fs := framer.AppendStreamFrames(frames, 1000) + Expect(fs).To(Equal([]wire.Frame{mdf, f})) }) It("skips a stream that was reported active, but was completed shortly after", func() { @@ -94,7 +108,7 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f})) + Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f})) }) It("skips a stream that was reported active, but doesn't have any data", func() { @@ -108,7 +122,7 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f})) + Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f})) }) It("pops from a stream multiple times, if it has enough data", func() { @@ -118,10 +132,10 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f1, true) stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id1) // only add it once - Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f1})) - Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) + Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f1})) + Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f2})) // no further calls to popStreamFrame, after popStreamFrame said there's no more data - Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(BeNil()) + Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(BeNil()) }) It("re-queues a stream at the end, if it has enough data", func() { @@ -135,9 +149,12 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id1) // only add it once framer.AddActiveStream(id2) - Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f11})) // first a frame from stream 1 - Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f2})) // then a frame from stream 2 - Expect(framer.PopStreamFrames(protocol.MinStreamFrameSize)).To(Equal([]*wire.StreamFrame{f12})) // then another frame from stream 1 + // first a frame from stream 1 + Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f11})) + // then a frame from stream 2 + Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f2})) + // then another frame from stream 1 + Expect(framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize)).To(Equal([]wire.Frame{f12})) }) It("only dequeues data from each stream once per packet", func() { @@ -150,7 +167,7 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, true) framer.AddActiveStream(id1) framer.AddActiveStream(id2) - Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f1, f2})) + Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f1, f2})) }) It("returns multiple normal frames in the order they were reported active", func() { @@ -162,7 +179,7 @@ var _ = Describe("Stream Framer", func() { stream2.EXPECT().popStreamFrame(gomock.Any()).Return(f2, false) framer.AddActiveStream(id2) framer.AddActiveStream(id1) - Expect(framer.PopStreamFrames(1000)).To(Equal([]*wire.StreamFrame{f2, f1})) + Expect(framer.AppendStreamFrames(nil, 1000)).To(Equal([]wire.Frame{f2, f1})) }) It("only asks a stream for data once, even if it was reported active multiple times", func() { @@ -171,11 +188,11 @@ var _ = Describe("Stream Framer", func() { stream1.EXPECT().popStreamFrame(gomock.Any()).Return(f, false) // only one call to this function framer.AddActiveStream(id1) framer.AddActiveStream(id1) - Expect(framer.PopStreamFrames(1000)).To(HaveLen(1)) + Expect(framer.AppendStreamFrames(nil, 1000)).To(HaveLen(1)) }) It("does not pop empty frames", func() { - fs := framer.PopStreamFrames(500) + fs := framer.AppendStreamFrames(nil, 500) Expect(fs).To(BeEmpty()) }) @@ -183,12 +200,12 @@ var _ = Describe("Stream Framer", func() { streamGetter.EXPECT().GetOrOpenSendStream(id1).Return(stream1, nil) stream1.EXPECT().popStreamFrame(protocol.MinStreamFrameSize).Return(&wire.StreamFrame{Data: []byte("foobar")}, false) framer.AddActiveStream(id1) - framer.PopStreamFrames(protocol.MinStreamFrameSize) + framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize) }) It("does not pop frames smaller than the minimum size", func() { // don't expect a call to PopStreamFrame() - framer.PopStreamFrames(protocol.MinStreamFrameSize - 1) + framer.AppendStreamFrames(nil, protocol.MinStreamFrameSize-1) }) It("stops iterating when the remaining size is smaller than the minimum STREAM frame size", func() { @@ -200,8 +217,8 @@ var _ = Describe("Stream Framer", func() { } stream1.EXPECT().popStreamFrame(protocol.ByteCount(500)).Return(f, false) framer.AddActiveStream(id1) - fs := framer.PopStreamFrames(500) - Expect(fs).To(Equal([]*wire.StreamFrame{f})) + fs := framer.AppendStreamFrames(nil, 500) + Expect(fs).To(Equal([]wire.Frame{f})) }) }) })