diff --git a/mock_ack_frame_source_test.go b/mock_ack_frame_source_test.go new file mode 100644 index 00000000..1f3cd578 --- /dev/null +++ b/mock_ack_frame_source_test.go @@ -0,0 +1,59 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: AckFrameSource) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockAckFrameSource is a mock of AckFrameSource interface +type MockAckFrameSource struct { + ctrl *gomock.Controller + recorder *MockAckFrameSourceMockRecorder +} + +// MockAckFrameSourceMockRecorder is the mock recorder for MockAckFrameSource +type MockAckFrameSourceMockRecorder struct { + mock *MockAckFrameSource +} + +// NewMockAckFrameSource creates a new mock instance +func NewMockAckFrameSource(ctrl *gomock.Controller) *MockAckFrameSource { + mock := &MockAckFrameSource{ctrl: ctrl} + mock.recorder = &MockAckFrameSourceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockAckFrameSource) EXPECT() *MockAckFrameSourceMockRecorder { + return m.recorder +} + +// GetAckFrame mocks base method +func (m *MockAckFrameSource) GetAckFrame() *wire.AckFrame { + ret := m.ctrl.Call(m, "GetAckFrame") + ret0, _ := ret[0].(*wire.AckFrame) + return ret0 +} + +// GetAckFrame indicates an expected call of GetAckFrame +func (mr *MockAckFrameSourceMockRecorder) GetAckFrame() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetAckFrame)) +} + +// GetStopWaitingFrame mocks base method +func (m *MockAckFrameSource) GetStopWaitingFrame(arg0 bool) *wire.StopWaitingFrame { + ret := m.ctrl.Call(m, "GetStopWaitingFrame", arg0) + ret0, _ := ret[0].(*wire.StopWaitingFrame) + return ret0 +} + +// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame +func (mr *MockAckFrameSourceMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockAckFrameSource)(nil).GetStopWaitingFrame), arg0) +} diff --git a/mock_packer_test.go b/mock_packer_test.go new file mode 100644 index 00000000..a9e49a00 --- /dev/null +++ b/mock_packer_test.go @@ -0,0 +1,129 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: Packer) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + ackhandler "github.com/lucas-clemente/quic-go/internal/ackhandler" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" + wire "github.com/lucas-clemente/quic-go/internal/wire" +) + +// MockPacker is a mock of Packer interface +type MockPacker struct { + ctrl *gomock.Controller + recorder *MockPackerMockRecorder +} + +// MockPackerMockRecorder is the mock recorder for MockPacker +type MockPackerMockRecorder struct { + mock *MockPacker +} + +// NewMockPacker creates a new mock instance +func NewMockPacker(ctrl *gomock.Controller) *MockPacker { + mock := &MockPacker{ctrl: ctrl} + mock.recorder = &MockPackerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPacker) EXPECT() *MockPackerMockRecorder { + return m.recorder +} + +// ChangeDestConnectionID mocks base method +func (m *MockPacker) ChangeDestConnectionID(arg0 protocol.ConnectionID) { + m.ctrl.Call(m, "ChangeDestConnectionID", arg0) +} + +// ChangeDestConnectionID indicates an expected call of ChangeDestConnectionID +func (mr *MockPackerMockRecorder) ChangeDestConnectionID(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeDestConnectionID", reflect.TypeOf((*MockPacker)(nil).ChangeDestConnectionID), arg0) +} + +// MaybePackAckPacket mocks base method +func (m *MockPacker) MaybePackAckPacket() (*packedPacket, error) { + ret := m.ctrl.Call(m, "MaybePackAckPacket") + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MaybePackAckPacket indicates an expected call of MaybePackAckPacket +func (mr *MockPackerMockRecorder) MaybePackAckPacket() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MaybePackAckPacket", reflect.TypeOf((*MockPacker)(nil).MaybePackAckPacket)) +} + +// PackConnectionClose mocks base method +func (m *MockPacker) PackConnectionClose(arg0 *wire.ConnectionCloseFrame) (*packedPacket, error) { + ret := m.ctrl.Call(m, "PackConnectionClose", arg0) + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackConnectionClose indicates an expected call of PackConnectionClose +func (mr *MockPackerMockRecorder) PackConnectionClose(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackConnectionClose", reflect.TypeOf((*MockPacker)(nil).PackConnectionClose), arg0) +} + +// PackPacket mocks base method +func (m *MockPacker) PackPacket() (*packedPacket, error) { + ret := m.ctrl.Call(m, "PackPacket") + ret0, _ := ret[0].(*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackPacket indicates an expected call of PackPacket +func (mr *MockPackerMockRecorder) PackPacket() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackPacket", reflect.TypeOf((*MockPacker)(nil).PackPacket)) +} + +// PackRetransmission mocks base method +func (m *MockPacker) PackRetransmission(arg0 *ackhandler.Packet) ([]*packedPacket, error) { + ret := m.ctrl.Call(m, "PackRetransmission", arg0) + ret0, _ := ret[0].([]*packedPacket) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PackRetransmission indicates an expected call of PackRetransmission +func (mr *MockPackerMockRecorder) PackRetransmission(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackRetransmission", reflect.TypeOf((*MockPacker)(nil).PackRetransmission), arg0) +} + +// QueueControlFrame mocks base method +func (m *MockPacker) QueueControlFrame(arg0 wire.Frame) { + m.ctrl.Call(m, "QueueControlFrame", arg0) +} + +// QueueControlFrame indicates an expected call of QueueControlFrame +func (mr *MockPackerMockRecorder) QueueControlFrame(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueControlFrame", reflect.TypeOf((*MockPacker)(nil).QueueControlFrame), arg0) +} + +// SetMaxPacketSize mocks base method +func (m *MockPacker) SetMaxPacketSize(arg0 protocol.ByteCount) { + m.ctrl.Call(m, "SetMaxPacketSize", arg0) +} + +// SetMaxPacketSize indicates an expected call of SetMaxPacketSize +func (mr *MockPackerMockRecorder) SetMaxPacketSize(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetMaxPacketSize", reflect.TypeOf((*MockPacker)(nil).SetMaxPacketSize), arg0) +} + +// SetOmitConnectionID mocks base method +func (m *MockPacker) SetOmitConnectionID() { + m.ctrl.Call(m, "SetOmitConnectionID") +} + +// SetOmitConnectionID indicates an expected call of SetOmitConnectionID +func (mr *MockPackerMockRecorder) SetOmitConnectionID() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOmitConnectionID", reflect.TypeOf((*MockPacker)(nil).SetOmitConnectionID)) +} diff --git a/mockgen.go b/mockgen.go index 5b7cd4f0..3667134d 100644 --- a/mockgen.go +++ b/mockgen.go @@ -6,9 +6,11 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_stream_sender_test.go github.com/lucas-clemente/quic-go streamSender" //go:generate sh -c "./mockgen_private.sh quic mock_stream_getter_test.go github.com/lucas-clemente/quic-go streamGetter" //go:generate sh -c "./mockgen_private.sh quic mock_stream_frame_source_test.go github.com/lucas-clemente/quic-go streamFrameSource" +//go:generate sh -c "./mockgen_private.sh quic mock_ack_frame_source_test.go github.com/lucas-clemente/quic-go ackFrameSource" //go:generate sh -c "./mockgen_private.sh quic mock_crypto_stream_test.go github.com/lucas-clemente/quic-go cryptoStream" //go:generate sh -c "./mockgen_private.sh quic mock_stream_manager_test.go github.com/lucas-clemente/quic-go streamManager" //go:generate sh -c "./mockgen_private.sh quic mock_unpacker_test.go github.com/lucas-clemente/quic-go unpacker" +//go:generate sh -c "./mockgen_private.sh quic mock_packer_test.go github.com/lucas-clemente/quic-go packer" //go:generate sh -c "./mockgen_private.sh quic mock_quic_aead_test.go github.com/lucas-clemente/quic-go quicAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner" diff --git a/packet_packer.go b/packet_packer.go index bda585d9..538bc3e0 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -15,6 +15,19 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) +type packer interface { + QueueControlFrame(frame wire.Frame) + + PackPacket() (*packedPacket, error) + MaybePackAckPacket() (*packedPacket, error) + PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) + PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error) + + SetOmitConnectionID() + ChangeDestConnectionID(protocol.ConnectionID) + SetMaxPacketSize(protocol.ByteCount) +} + type packedPacket struct { header *wire.Header raw []byte @@ -62,6 +75,19 @@ type streamFrameSource interface { AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame } +// sentAndReceivedPacketManager is only needed until STOP_WAITING is removed +type sentAndReceivedPacketManager struct { + ackhandler.SentPacketHandler + ackhandler.ReceivedPacketHandler +} + +var _ ackFrameSource = &sentAndReceivedPacketManager{} + +type ackFrameSource interface { + GetAckFrame() *wire.AckFrame + GetStopWaitingFrame(bool) *wire.StopWaitingFrame +} + type packetPacker struct { destConnID protocol.ConnectionID srcConnID protocol.ConnectionID @@ -76,18 +102,19 @@ type packetPacker struct { packetNumberGenerator *packetNumberGenerator getPacketNumberLen func(protocol.PacketNumber) protocol.PacketNumberLen streams streamFrameSource + acks ackFrameSource controlFrameMutex sync.Mutex controlFrames []wire.Frame - stopWaiting *wire.StopWaitingFrame - ackFrame *wire.AckFrame omitConnectionID bool maxPacketSize protocol.ByteCount hasSentPacket bool // has the packetPacker already sent a packet numNonRetransmittableAcks int } +var _ packer = &packetPacker{} + func newPacketPacker( destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, @@ -98,6 +125,7 @@ func newPacketPacker( divNonce []byte, cryptoSetup sealingManager, streamFramer streamFrameSource, + acks ackFrameSource, perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { @@ -110,6 +138,7 @@ func newPacketPacker( perspective: perspective, version: version, streams: streamFramer, + acks: acks, getPacketNumberLen: getPacketNumberLen, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), maxPacketSize: getMaxPacketSize(remoteAddr), @@ -130,20 +159,22 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac }, err } -func (p *packetPacker) PackAckPacket() (*packedPacket, error) { - if p.ackFrame == nil { - return nil, errors.New("packet packer BUG: no ack frame queued") +func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { + ack := p.acks.GetAckFrame() + if ack == nil { + return nil, nil } encLevel, sealer := p.cryptoSetup.GetSealer() header := p.getHeader(encLevel) - frames := []wire.Frame{p.ackFrame} - if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC - p.stopWaiting.PacketNumber = header.PacketNumber - p.stopWaiting.PacketNumberLen = header.PacketNumberLen - frames = append(frames, p.stopWaiting) - p.stopWaiting = nil + frames := []wire.Frame{ack} + // add a STOP_WAITING frame, when using gQUIC + if p.version.UsesStopWaitingFrames() { + if swf := p.acks.GetStopWaitingFrame(false); swf != nil { + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + frames = append(frames, swf) + } } - p.ackFrame = nil raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ header: header, @@ -175,6 +206,11 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP var packets []*packedPacket encLevel, sealer := p.cryptoSetup.GetSealer() + var swf *wire.StopWaitingFrame + // for gQUIC: add a STOP_WAITING for *every* retransmission + if p.version.UsesStopWaitingFrames() { + swf = p.acks.GetStopWaitingFrame(true) + } for len(controlFrames) > 0 || len(streamFrames) > 0 { var frames []wire.Frame var length protocol.ByteCount @@ -186,19 +222,15 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP } maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength - // for gQUIC: add a STOP_WAITING for *every* retransmission if p.version.UsesStopWaitingFrames() { - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") - } - // create a new StopWaitingFrame, since we might need to send more than one packet as a retransmission - swf := &wire.StopWaitingFrame{ - LeastUnacked: p.stopWaiting.LeastUnacked, + // create a new STOP_WAIITNG Frame, since we might need to send more than one packet as a retransmission + stopWaitingFrame := &wire.StopWaitingFrame{ + LeastUnacked: swf.LeastUnacked, PacketNumber: header.PacketNumber, PacketNumberLen: header.PacketNumberLen, } - length += swf.Length(p.version) - frames = append(frames, swf) + length += stopWaitingFrame.Length(p.version) + frames = append(frames, stopWaitingFrame) } for len(controlFrames) > 0 { @@ -253,7 +285,6 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP encryptionLevel: encLevel, }) } - p.stopWaiting = nil return packets, nil } @@ -271,13 +302,9 @@ func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (* header.Type = packet.PacketType var frames []wire.Frame if p.version.UsesStopWaitingFrames() { // for gQUIC: pack a STOP_WAITING first - if p.stopWaiting == nil { - return nil, errors.New("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame") - } - swf := p.stopWaiting + swf := p.acks.GetStopWaitingFrame(true) swf.PacketNumber = header.PacketNumber swf.PacketNumberLen = header.PacketNumberLen - p.stopWaiting = nil frames = append([]wire.Frame{swf}, packet.Frames...) } else { frames = packet.Frames @@ -310,13 +337,9 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if err != nil { return nil, err } - if p.stopWaiting != nil { - p.stopWaiting.PacketNumber = header.PacketNumber - p.stopWaiting.PacketNumberLen = header.PacketNumberLen - } maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLength - frames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel)) + frames, err := p.composeNextPacket(header, maxSize, p.canSendData(encLevel)) if err != nil { return nil, err } @@ -325,25 +348,17 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { if len(frames) == 0 { return nil, nil } - // Don't send out packets that only contain a StopWaitingFrame - if len(frames) == 1 && p.stopWaiting != nil { - return nil, nil - } - if p.ackFrame != nil { - // check if this packet only contains an ACK (and maybe a STOP_WAITING) - if len(frames) == 1 || (p.stopWaiting != nil && len(frames) == 2) { - if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { - frames = append(frames, &wire.PingFrame{}) - p.numNonRetransmittableAcks = 0 - } else { - p.numNonRetransmittableAcks++ - } - } else { + // check if this packet only contains an ACK (and maybe a STOP_WAITING) + if !ackhandler.HasRetransmittableFrames(frames) { + if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks { + frames = append(frames, &wire.PingFrame{}) p.numNonRetransmittableAcks = 0 + } else { + p.numNonRetransmittableAcks++ } + } else { + p.numNonRetransmittableAcks = 0 } - p.stopWaiting = nil - p.ackFrame = nil raw, err := p.writeAndSealPacket(header, frames, sealer) if err != nil { @@ -381,6 +396,7 @@ func (p *packetPacker) packCryptoPacket() (*packedPacket, error) { } func (p *packetPacker) composeNextPacket( + header *wire.Header, // only needed to fill in the STOP_WAITING frame maxFrameSize protocol.ByteCount, canSendStreamFrames bool, ) ([]wire.Frame, error) { @@ -388,13 +404,19 @@ func (p *packetPacker) composeNextPacket( var frames []wire.Frame // STOP_WAITING and ACK will always fit - if p.ackFrame != nil { // ACKs need to go first, so that the sentPacketHandler will recognize them - frames = append(frames, p.ackFrame) - length += p.ackFrame.Length(p.version) - } - if p.stopWaiting != nil { // a STOP_WAITING will only be queued when using gQUIC - frames = append(frames, p.stopWaiting) - length += p.stopWaiting.Length(p.version) + // ACKs need to go first, so that the sentPacketHandler will recognize them + if ack := p.acks.GetAckFrame(); ack != nil { + frames = append(frames, ack) + length += ack.Length(p.version) + // add a STOP_WAITING, for gQUIC + if p.version.UsesStopWaitingFrames() { + if swf := p.acks.GetStopWaitingFrame(false); swf != nil { + swf.PacketNumber = header.PacketNumber + swf.PacketNumberLen = header.PacketNumberLen + frames = append(frames, swf) + length += swf.Length(p.version) + } + } } p.controlFrameMutex.Lock() @@ -410,10 +432,6 @@ func (p *packetPacker) composeNextPacket( } p.controlFrameMutex.Unlock() - if length > maxFrameSize { - return nil, fmt.Errorf("Packet Packer BUG: packet payload (%d) too large (%d)", length, maxFrameSize) - } - if !canSendStreamFrames { return frames, nil } @@ -440,16 +458,9 @@ func (p *packetPacker) composeNextPacket( } func (p *packetPacker) QueueControlFrame(frame wire.Frame) { - switch f := frame.(type) { - case *wire.StopWaitingFrame: - p.stopWaiting = f - case *wire.AckFrame: - p.ackFrame = f - default: - p.controlFrameMutex.Lock() - p.controlFrames = append(p.controlFrames, f) - p.controlFrameMutex.Unlock() - } + p.controlFrameMutex.Lock() + p.controlFrames = append(p.controlFrames, frame) + p.controlFrameMutex.Unlock() } func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header { diff --git a/packet_packer_test.go b/packet_packer_test.go index 02cba9ed..df791775 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -60,6 +60,7 @@ var _ = Describe("Packet packer", func() { publicHeaderLen protocol.ByteCount maxFrameSize protocol.ByteCount mockStreamFramer *MockStreamFrameSource + mockAckFramer *MockAckFrameSource divNonce []byte token []byte ) @@ -84,6 +85,7 @@ var _ = Describe("Packet packer", func() { mockSender := NewMockStreamSender(mockCtrl) mockSender.EXPECT().onHasStreamData(gomock.Any()).AnyTimes() mockStreamFramer = NewMockStreamFrameSource(mockCtrl) + mockAckFramer = NewMockAckFrameSource(mockCtrl) divNonce = bytes.Repeat([]byte{'e'}, 32) token = []byte("initial token") @@ -97,6 +99,7 @@ var _ = Describe("Packet packer", func() { divNonce, &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}, mockStreamFramer, + mockAckFramer, protocol.PerspectiveServer, version, ) @@ -125,6 +128,7 @@ var _ = Describe("Packet packer", func() { }) It("returns nil when no packet is queued", func() { + mockAckFramer.EXPECT().GetAckFrame() mockStreamFramer.EXPECT().HasCryptoStreamData() mockStreamFramer.EXPECT().AppendStreamFrames(nil, gomock.Any()) p, err := packer.PackPacket() @@ -134,6 +138,7 @@ var _ = Describe("Packet packer", func() { It("packs single packets", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() f := &wire.StreamFrame{ StreamID: 5, Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, @@ -150,6 +155,7 @@ var _ = Describe("Packet packer", func() { It("stores the encryption level a packet was sealed with", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() expectAppendStreamFrames(&wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), @@ -356,6 +362,7 @@ var _ = Describe("Packet packer", func() { It("packs only control frames", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() expectAppendStreamFrames() packer.QueueControlFrame(&wire.RstStreamFrame{}) packer.QueueControlFrame(&wire.MaxDataFrame{}) @@ -368,6 +375,7 @@ var _ = Describe("Packet packer", func() { It("increases the packet number", func() { mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockAckFramer.EXPECT().GetAckFrame().Times(2) expectAppendStreamFrames() expectAppendStreamFrames() packer.QueueControlFrame(&wire.RstStreamFrame{}) @@ -381,49 +389,33 @@ var _ = Describe("Packet packer", func() { Expect(p2.header.PacketNumber).To(BeNumerically(">", p1.header.PacketNumber)) }) - It("packs a STOP_WAITING frame first", func() { + It("packs ACKs and STOP_WAITING frames first", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}} + swf := &wire.StopWaitingFrame{LeastUnacked: 10} + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) expectAppendStreamFrames() packer.packetNumberGenerator.next = 15 - swf := &wire.StopWaitingFrame{LeastUnacked: 10} - packer.QueueControlFrame(&wire.RstStreamFrame{}) - packer.QueueControlFrame(swf) + cf := &wire.RstStreamFrame{} + packer.QueueControlFrame(cf) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0]).To(Equal(swf)) + Expect(p.frames).To(Equal([]wire.Frame{ack, swf, cf})) }) It("sets the LeastUnackedDelta length of a STOP_WAITING frame", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337 - 0x100} + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 100}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) expectAppendStreamFrames() packer.packetNumberGenerator.next = 0x1337 - swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337 - 0x100} - packer.QueueControlFrame(&wire.RstStreamFrame{}) - packer.QueueControlFrame(swf) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.frames[0].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) - }) - - It("does not pack a packet containing only a STOP_WAITING frame", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() - expectAppendStreamFrames() - swf := &wire.StopWaitingFrame{LeastUnacked: 10} - packer.QueueControlFrame(swf) - p, err := packer.PackPacket() - Expect(p).To(BeNil()) - Expect(err).ToNot(HaveOccurred()) - }) - - It("packs a packet if it has queued control frames, but no new control frames", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() - expectAppendStreamFrames() - packer.controlFrames = []wire.Frame{&wire.BlockedFrame{}} - p, err := packer.PackPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(p).ToNot(BeNil()) + Expect(p.frames).To(HaveLen(2)) + Expect(p.frames[1].(*wire.StopWaitingFrame).PacketNumberLen).To(Equal(protocol.PacketNumberLen2)) }) It("refuses to send a packet that doesn't contain crypto stream data, if it has never sent a packet before", func() { @@ -435,7 +427,8 @@ var _ = Describe("Packet packer", func() { Expect(p).To(BeNil()) }) - It("packs many control frames into 1 packets", func() { + It("packs many control frames into one packets", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) f := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 1, Smallest: 1}}} b := &bytes.Buffer{} err := f.Write(b, packer.version) @@ -446,15 +439,16 @@ var _ = Describe("Packet packer", func() { controlFrames = append(controlFrames, f) } packer.controlFrames = controlFrames - payloadFrames, err := packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(BeEmpty()) }) It("packs a lot of control frames into 2 packets if they don't fit into one", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) blockedFrame := &wire.BlockedFrame{} maxFramesPerPacket := int(maxFrameSize) / int(blockedFrame.Length(packer.version)) var controlFrames []wire.Frame @@ -462,15 +456,16 @@ var _ = Describe("Packet packer", func() { controlFrames = append(controlFrames, blockedFrame) } packer.controlFrames = controlFrames - payloadFrames, err := packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err := packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(HaveLen(maxFramesPerPacket)) - payloadFrames, err = packer.composeNextPacket(maxFrameSize, false) + payloadFrames, err = packer.composeNextPacket(nil, maxFrameSize, false) Expect(err).ToNot(HaveOccurred()) Expect(payloadFrames).To(HaveLen(10)) }) It("only increases the packet number when there is an actual packet to send", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) expectAppendStreamFrames() packer.packetNumberGenerator.nextToSkip = 1000 @@ -493,7 +488,8 @@ var _ = Describe("Packet packer", func() { sendMaxNumNonRetransmittableAcks := func() { mockStreamFramer.EXPECT().HasCryptoStreamData().Times(protocol.MaxNonRetransmittableAcks) for i := 0; i < protocol.MaxNonRetransmittableAcks; i++ { - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) expectAppendStreamFrames() p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) @@ -504,15 +500,18 @@ var _ = Describe("Packet packer", func() { It("adds a PING frame when it's supposed to send a retransmittable packet", func() { sendMaxNumNonRetransmittableAcks() - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) expectAppendStreamFrames() - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(ContainElement(&wire.PingFrame{})) // make sure the next packet doesn't contain another PING - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 2, Largest: 2}}}) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) expectAppendStreamFrames() p, err = packer.PackPacket() Expect(p).ToNot(BeNil()) @@ -522,13 +521,18 @@ var _ = Describe("Packet packer", func() { It("waits until there's something to send before adding a PING frame", func() { sendMaxNumNonRetransmittableAcks() - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + // nothing to send + mockStreamFramer.EXPECT().HasCryptoStreamData() expectAppendStreamFrames() + mockAckFramer.EXPECT().GetAckFrame() p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) + // now add some frame to send expectAppendStreamFrames() - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}}) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(2)) @@ -538,20 +542,20 @@ var _ = Describe("Packet packer", func() { It("doesn't send a PING if it already sent another retransmittable frame", func() { sendMaxNumNonRetransmittableAcks() mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() expectAppendStreamFrames() packer.QueueControlFrame(&wire.MaxDataFrame{}) - packer.QueueControlFrame(&wire.StopWaitingFrame{}) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) Expect(p.frames).ToNot(ContainElement(&wire.PingFrame{})) }) }) Context("STREAM frame handling", func() { It("does not split a STREAM frame with maximum size, for gQUIC frames", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { f := &wire.StreamFrame{ Offset: 1, @@ -567,6 +571,9 @@ var _ = Describe("Packet packer", func() { Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(maxPacketSize))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + // make sure there's nothing else to send + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -574,7 +581,8 @@ var _ = Describe("Packet packer", func() { It("does not split a STREAM frame with maximum size, for IETF draft style frame", func() { packer.version = versionIETFFrames - mockStreamFramer.EXPECT().HasCryptoStreamData().Times(2) + mockAckFramer.EXPECT().GetAckFrame() + mockStreamFramer.EXPECT().HasCryptoStreamData() mockStreamFramer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(_ []wire.Frame, maxSize protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) { f := &wire.StreamFrame{ Offset: 1, @@ -590,6 +598,9 @@ var _ = Describe("Packet packer", func() { Expect(p.frames).To(HaveLen(1)) Expect(p.raw).To(HaveLen(int(maxPacketSize))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) + // make sure there's nothing else to send + mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() p, err = packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) @@ -612,21 +623,23 @@ var _ = Describe("Packet packer", func() { DataLenPresent: true, } mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() expectAppendStreamFrames(f1, f2, f3) p, err := packer.PackPacket() Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(HaveLen(3)) Expect(p.frames[0].(*wire.StreamFrame).Data).To(Equal([]byte("frame 1"))) - Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) - Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) Expect(p.frames[0].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[1].(*wire.StreamFrame).Data).To(Equal([]byte("frame 2"))) Expect(p.frames[1].(*wire.StreamFrame).DataLenPresent).To(BeTrue()) + Expect(p.frames[2].(*wire.StreamFrame).Data).To(Equal([]byte("frame 3"))) Expect(p.frames[2].(*wire.StreamFrame).DataLenPresent).To(BeFalse()) }) It("refuses to send unencrypted stream data on a data stream", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted p, err := packer.PackPacket() @@ -636,6 +649,7 @@ var _ = Describe("Packet packer", func() { It("sends non forward-secure data as the client", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() f := &wire.StreamFrame{ StreamID: 5, Data: []byte("foobar"), @@ -651,6 +665,7 @@ var _ = Describe("Packet packer", func() { It("does not send non forward-secure data as the server", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + mockAckFramer.EXPECT().GetAckFrame() // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure p, err := packer.PackPacket() @@ -709,10 +724,11 @@ var _ = Describe("Packet packer", func() { It("does not pack STREAM frames if not allowed", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 10, Smallest: 1}}} + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) // don't expect a call to mockStreamFramer.PopStreamFrames packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 10, Smallest: 1}}} - packer.QueueControlFrame(ack) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) Expect(p.frames).To(Equal([]wire.Frame{ack})) @@ -721,37 +737,25 @@ var _ = Describe("Packet packer", func() { It("packs a single ACK", func() { mockStreamFramer.EXPECT().HasCryptoStreamData() - expectAppendStreamFrames() ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Largest: 42, Smallest: 1}}} - packer.QueueControlFrame(ack) + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false) + expectAppendStreamFrames() p, err := packer.PackPacket() Expect(err).NotTo(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.frames[0]).To(Equal(ack)) }) - It("does not return nil if we only have a single ACK but request it to be sent", func() { - mockStreamFramer.EXPECT().HasCryptoStreamData() - expectAppendStreamFrames() - ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 1}}} - packer.QueueControlFrame(ack) - p, err := packer.PackPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p).ToNot(BeNil()) - }) - Context("retransmitting of handshake packets", func() { - swf := &wire.StopWaitingFrame{LeastUnacked: 1} sf := &wire.StreamFrame{ StreamID: 1, Data: []byte("foobar"), } - BeforeEach(func() { - packer.QueueControlFrame(swf) - }) - It("packs a retransmission for a packet sent with no encryption", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 1} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packet := &ackhandler.Packet{ PacketType: protocol.PacketTypeHandshake, EncryptionLevel: protocol.EncryptionUnencrypted, @@ -778,7 +782,9 @@ var _ = Describe("Packet packer", func() { Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) }) - It("packs a retransmission for a packet sent with initial encryption", func() { + It("packs a retransmission for a packet sent with secure encryption", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 1} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packet := &ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, Frames: []wire.Frame{sf}, @@ -788,25 +794,16 @@ var _ = Describe("Packet packer", func() { Expect(p).To(HaveLen(1)) Expect(p[0].frames).To(Equal([]wire.Frame{swf, sf})) Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionSecure)) - // a packet sent by the server with initial encryption contains the SHLO + // a packet sent by the server with secure encryption contains the SHLO // it needs to have a diversification nonce Expect(p[0].raw).To(ContainSubstring(string(divNonce))) }) - It("includes the diversification nonce on packets sent with initial encryption", func() { - packet := &ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionSecure, - Frames: []wire.Frame{sf}, - } - p, err := packer.PackRetransmission(packet) - Expect(err).ToNot(HaveOccurred()) - Expect(p).To(HaveLen(1)) - Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionSecure)) - }) - // this should never happen, since non forward-secure packets are limited to a size smaller than MaxPacketSize, such that it is always possible to retransmit them without splitting the StreamFrame // (note that the retransmitted packet needs to have enough space for the StopWaitingFrame) It("refuses to send a packet larger than MaxPacketSize", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 1} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packet := &ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, Frames: []wire.Frame{ @@ -873,23 +870,13 @@ var _ = Describe("Packet packer", func() { Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p[0].header.Token).To(Equal(token)) }) - - It("refuses to retransmit packets without a STOP_WAITING Frame", func() { - packer.stopWaiting = nil - _, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionSecure, - }) - Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")) - }) }) Context("retransmission of forward-secure packets", func() { - BeforeEach(func() { - packer.packetNumberGenerator.next = 15 - packer.stopWaiting = &wire.StopWaitingFrame{LeastUnacked: 7} - }) - It("retransmits a small packet", func() { + swf := &wire.StopWaitingFrame{LeastUnacked: 7} + packer.packetNumberGenerator.next = 10 + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) frames := []wire.Frame{ &wire.MaxDataFrame{ByteOffset: 0x1234}, &wire.StreamFrame{StreamID: 42, Data: []byte("foobar")}, @@ -910,15 +897,6 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[1:]).To(Equal(frames)) }) - It("refuses to retransmit packets without a STOP_WAITING Frame", func() { - packer.stopWaiting = nil - _, err := packer.PackRetransmission(&ackhandler.Packet{ - EncryptionLevel: protocol.EncryptionForwardSecure, - Frames: []wire.Frame{&wire.MaxDataFrame{ByteOffset: 0x1234}}, - }) - Expect(err).To(MatchError("PacketPacker BUG: Handshake retransmissions must contain a STOP_WAITING frame")) - }) - It("packs two packets for retransmission if the original packet contained many control frames", func() { var frames []wire.Frame var totalLen protocol.ByteCount @@ -928,6 +906,9 @@ var _ = Describe("Packet packer", func() { frames = append(frames, f) totalLen += f.Length(packer.version) } + packer.packetNumberGenerator.next = 10 + swf := &wire.StopWaitingFrame{LeastUnacked: 7} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packets, err := packer.PackRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionForwardSecure, Frames: frames, @@ -945,6 +926,8 @@ var _ = Describe("Packet packer", func() { }) It("splits a STREAM frame that doesn't fit", func() { + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) packets, err := packer.PackRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.EncryptionForwardSecure, Frames: []wire.Frame{&wire.StreamFrame{ @@ -972,6 +955,8 @@ var _ = Describe("Packet packer", func() { }) It("packs two packets for retransmission if the original packet contained many STREAM frames", func() { + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) var frames []wire.Frame var totalLen protocol.ByteCount // pack a bunch of control frames, such that the packet is way bigger than a single packet @@ -1001,6 +986,8 @@ var _ = Describe("Packet packer", func() { }) It("correctly sets the DataLenPresent on STREAM frames", func() { + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetStopWaitingFrame(true).Return(swf) frames := []wire.Frame{ &wire.StreamFrame{StreamID: 4, Data: []byte("foobar"), DataLenPresent: true}, &wire.StreamFrame{StreamID: 5, Data: []byte("barfoo")}, @@ -1025,29 +1012,36 @@ var _ = Describe("Packet packer", func() { }) Context("packing ACK packets", func() { - It("packs ACK packets", func() { - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}) - p, err := packer.PackAckPacket() - Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(HaveLen(1)) - Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - ack := p.frames[0].(*wire.AckFrame) - Expect(ack.LargestAcked()).To(Equal(protocol.PacketNumber(10))) + It("doesn't pack a packet if there's no ACK to send", func() { + mockAckFramer.EXPECT().GetAckFrame() + p, err := packer.MaybePackAckPacket() + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) }) - It("packs ACK packets with STOP_WAITING frames", func() { - packer.QueueControlFrame(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}}) - packer.QueueControlFrame(&wire.StopWaitingFrame{}) - p, err := packer.PackAckPacket() + It("packs ACK packets", func() { + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + swf := &wire.StopWaitingFrame{} + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + mockAckFramer.EXPECT().GetStopWaitingFrame(false).Return(swf) + p, err := packer.MaybePackAckPacket() Expect(err).NotTo(HaveOccurred()) - Expect(p.frames).To(HaveLen(2)) - Expect(p.frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.frames[1]).To(Equal(&wire.StopWaitingFrame{PacketNumber: 1, PacketNumberLen: 2})) + Expect(p.frames).To(Equal([]wire.Frame{ack, swf})) + }) + + It("doesn't add a STOP_WAITING frame for IETF QUIC", func() { + packer.version = versionIETFFrames + ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 10}}} + mockAckFramer.EXPECT().GetAckFrame().Return(ack) + p, err := packer.MaybePackAckPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(p.frames).To(Equal([]wire.Frame{ack})) }) }) Context("max packet size", func() { It("sets the maximum packet size", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) for i := 0; i < 10*int(maxPacketSize); i++ { packer.QueueControlFrame(&wire.PingFrame{}) } @@ -1066,6 +1060,7 @@ var _ = Describe("Packet packer", func() { }) It("doesn't increase the max packet size", func() { + mockAckFramer.EXPECT().GetAckFrame().Times(2) for i := 0; i < 10*int(maxPacketSize); i++ { packer.QueueControlFrame(&wire.PingFrame{}) } diff --git a/session.go b/session.go index 478c1d28..b6087cf2 100644 --- a/session.go +++ b/session.go @@ -97,7 +97,7 @@ type session struct { connFlowController flowcontrol.ConnectionFlowController unpacker unpacker - packer *packetPacker + packer packer cryptoStreamHandler cryptoStreamHandler @@ -216,6 +216,7 @@ func newSession( divNonce, cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -289,6 +290,7 @@ var newClientSession = func( nil, // no diversification nonce cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -344,6 +346,7 @@ func newTLSServerSession( nil, // no diversification nonce cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -408,6 +411,7 @@ var newTLSClientSession = func( nil, // no diversification nonce cs, s.streamFramer, + sentAndReceivedPacketManager{s.sentPacketHandler, s.receivedPacketHandler}, s.perspective, s.version, ) @@ -417,6 +421,7 @@ var newTLSClientSession = func( func (s *session) preSetup() { s.rttStats = &congestion.RTTStats{} s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version) + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.connFlowController = flowcontrol.NewConnectionFlowController( protocol.ReceiveConnectionFlowControlWindow, protocol.ByteCount(s.config.MaxReceiveConnectionFlowControlWindow), @@ -439,7 +444,6 @@ func (s *session) postSetup() error { s.lastNetworkActivityTime = now s.sessionCreationTime = now - s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version) s.windowUpdateQueue = newWindowUpdateQueue(s.streamsMap, s.connFlowController, s.packer.QueueControlFrame) return nil } @@ -988,21 +992,13 @@ sendLoop: } func (s *session) maybeSendAckOnlyPacket() error { - ack := s.receivedPacketHandler.GetAckFrame() - if ack == nil { - return nil - } - s.packer.QueueControlFrame(ack) - - if s.version.UsesStopWaitingFrames() { // for gQUIC, maybe add a STOP_WAITING - if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { - s.packer.QueueControlFrame(swf) - } - } - packet, err := s.packer.PackAckPacket() + packet, err := s.packer.MaybePackAckPacket() if err != nil { return err } + if packet == nil { + return nil + } s.sentPacketHandler.SentPacket(packet.ToAckHandlerPacket()) return s.sendPackedPacket(packet) } @@ -1033,9 +1029,6 @@ func (s *session) maybeSendRetransmission() (bool, error) { s.logger.Debugf("Dequeueing retransmission for packet 0x%x", retransmitPacket.PacketNumber) } - if s.version.UsesStopWaitingFrames() { - s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - } packets, err := s.packer.PackRetransmission(retransmitPacket) if err != nil { return false, err @@ -1060,9 +1053,6 @@ func (s *session) sendProbePacket() error { } s.logger.Debugf("Sending a retransmission for %#x as a probe packet.", p.PacketNumber) - if s.version.UsesStopWaitingFrames() { - s.packer.QueueControlFrame(s.sentPacketHandler.GetStopWaitingFrame(true)) - } packets, err := s.packer.PackRetransmission(p) if err != nil { return err @@ -1086,15 +1076,6 @@ func (s *session) sendPacket() (bool, error) { } s.windowUpdateQueue.QueueAll() - if ack := s.receivedPacketHandler.GetAckFrame(); ack != nil { - s.packer.QueueControlFrame(ack) - if s.version.UsesStopWaitingFrames() { - if swf := s.sentPacketHandler.GetStopWaitingFrame(false); swf != nil { - s.packer.QueueControlFrame(swf) - } - } - } - packet, err := s.packer.PackPacket() if err != nil || packet == nil { return false, err diff --git a/session_test.go b/session_test.go index 26cc4650..fa6fb390 100644 --- a/session_test.go +++ b/session_test.go @@ -73,6 +73,7 @@ var _ = Describe("Session", func() { mconn *mockConnection cryptoSetup *mockCryptoSetup streamManager *MockStreamManager + packer *MockPacker handshakeChan chan<- struct{} ) @@ -121,6 +122,8 @@ var _ = Describe("Session", func() { sess = pSess.(*session) streamManager = NewMockStreamManager(mockCtrl) sess.streamsMap = streamManager + packer = NewMockPacker(mockCtrl) + sess.packer = packer }) AfterEach(func() { @@ -442,11 +445,10 @@ var _ = Describe("Session", func() { }) It("handles PATH_CHALLENGE frames", func() { - err := sess.handleFrames([]wire.Frame{&wire.PathChallengeFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}}, protocol.EncryptionUnspecified) + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + packer.EXPECT().QueueControlFrame(&wire.PathResponseFrame{Data: data}) + err := sess.handleFrames([]wire.Frame{&wire.PathChallengeFrame{Data: data}}, protocol.EncryptionUnspecified) Expect(err).ToNot(HaveOccurred()) - Expect(sess.packer.controlFrames).To(HaveLen(1)) - Expect(sess.packer.controlFrames[0]).To(BeAssignableToTypeOf(&wire.PathResponseFrame{})) - Expect(sess.packer.controlFrames[0].(*wire.PathResponseFrame).Data).To(Equal([8]byte{1, 2, 3, 4, 5, 6, 7, 8})) }) It("handles BLOCKED frames", func() { @@ -515,21 +517,20 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{raw: []byte("connection close")}, nil) + Expect(sess.Close()).To(Succeed()) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) - buf := &bytes.Buffer{} - err := (&wire.ConnectionCloseFrame{ErrorCode: qerr.PeerGoingAway}).Write(buf, sess.version) - Expect(err).ToNot(HaveOccurred()) - Expect(mconn.written).To(Receive(ContainSubstring(buf.String()))) + Expect(mconn.written).To(Receive(ContainSubstring("connection close"))) Expect(sess.Context().Done()).To(BeClosed()) }) It("only closes once", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.PeerGoingAway, "")) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - sess.Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + Expect(sess.Close()).To(Succeed()) + Expect(sess.Close()).To(Succeed()) Eventually(areSessionsRunning).Should(BeFalse()) Expect(mconn.written).To(HaveLen(1)) Expect(sess.Context().Done()).To(BeClosed()) @@ -539,6 +540,7 @@ var _ = Describe("Session", func() { testErr := errors.New("test error") streamManager.EXPECT().CloseWithError(qerr.Error(0x1337, testErr.Error())) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.CloseWithError(0x1337, testErr) Eventually(areSessionsRunning).Should(BeFalse()) Expect(sess.Context().Done()).To(BeClosed()) @@ -564,6 +566,7 @@ var _ = Describe("Session", func() { It("cancels the context when the run loop exists", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) returned := make(chan struct{}) go func() { defer GinkgoRecover() @@ -625,6 +628,7 @@ var _ = Describe("Session", func() { testErr := errors.New("unpack error") unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) hdr.PacketNumber = 5 done := make(chan struct{}) go func() { @@ -702,53 +706,55 @@ var _ = Describe("Session", func() { }) Context("sending packets", func() { - BeforeEach(func() { - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - }) + getPacket := func(pn protocol.PacketNumber) *packedPacket { + data := *getPacketBuffer() + data = append(data, []byte("foobar")...) + return &packedPacket{ + raw: data, + header: &wire.Header{PacketNumber: pn}, + } + } - It("sends ACK frames", func() { - packetNumber := protocol.PacketNumber(0x035e) - err := sess.receivedPacketHandler.ReceivedPacket(packetNumber, time.Now(), true) + It("sends packets", func() { + packer.EXPECT().PackPacket().Return(getPacket(1), nil) + err := sess.receivedPacketHandler.ReceivedPacket(0x035e, time.Now(), true) Expect(err).ToNot(HaveOccurred()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) - Expect(mconn.written).To(Receive(ContainSubstring(string([]byte{0x03, 0x5e})))) }) - It("adds MAX_STREAM_DATA frames", func() { - sess.windowUpdateQueue.callback(&wire.MaxStreamDataFrame{ - StreamID: 2, - ByteOffset: 20, - }) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(ContainElement(&wire.MaxStreamDataFrame{StreamID: 2, ByteOffset: 20})) - }) - sess.sentPacketHandler = sph + It("doesn't send packets if there's nothing to send", func() { + packer.EXPECT().PackPacket().Return(getPacket(2), nil) + err := sess.receivedPacketHandler.ReceivedPacket(0x035e, time.Now(), true) + Expect(err).ToNot(HaveOccurred()) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) }) + It("sends ACK only packets", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAck) + sph.EXPECT().ShouldSendNumPackets().Return(1000) + packer.EXPECT().MaybePackAckPacket() + sess.sentPacketHandler = sph + Expect(sess.sendPackets()).To(Succeed()) + }) + It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) + packer.EXPECT().PackPacket().Return(getPacket(1), nil) sess.connFlowController = fc - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(Equal([]wire.Frame{ - &wire.BlockedFrame{Offset: 1337}, - })) - }) - sess.sentPacketHandler = sph + packer.EXPECT().QueueControlFrame(&wire.BlockedFrame{Offset: 1337}) sent, err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sent).To(BeTrue()) }) - It("sends public reset", func() { + It("sends PUBLIC_RESET", func() { err := sess.sendPublicReset(1) Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) @@ -782,57 +788,80 @@ var _ = Describe("Session", func() { }) It("sends a retransmission and a regular packet in the same run", func() { + packetToRetransmit := &ackhandler.Packet{ + PacketNumber: 10, + PacketType: protocol.PacketTypeHandshake, + } + retransmittedPacket := getPacket(123) + newPacket := getPacket(234) sess.windowUpdateQueue.callback(&wire.MaxDataFrame{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 10, - PacketType: protocol.PacketTypeHandshake, - }) + sph.EXPECT().DequeuePacketForRetransmission().Return(packetToRetransmit) sph.EXPECT().SendMode().Return(ackhandler.SendRetransmission) sph.EXPECT().SendMode().Return(ackhandler.SendAny) sph.EXPECT().ShouldSendNumPackets().Return(2) sph.EXPECT().TimeUntilSend() - sph.EXPECT().GetStopWaitingFrame(gomock.Any()).Return(&wire.StopWaitingFrame{}) gomock.InOrder( + packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil), sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(10)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { Expect(packets).To(HaveLen(1)) - Expect(len(packets[0].Frames)).To(BeNumerically(">", 0)) - Expect(packets[0].Frames[0]).To(BeAssignableToTypeOf(&wire.StopWaitingFrame{})) - Expect(packets[0].SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(123))) }), + packer.EXPECT().PackPacket().Return(newPacket, nil), sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(1)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.MaxDataFrame{})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(234))) }), ) sess.sentPacketHandler = sph - err := sess.sendPackets() - Expect(err).ToNot(HaveOccurred()) + Expect(sess.sendPackets()).To(Succeed()) + }) + + It("sends multiple packets, if the retransmission is split", func() { + sess.version = versionIETFFrames + packet := &ackhandler.Packet{ + PacketNumber: 42, + Frames: []wire.Frame{&wire.StreamFrame{ + StreamID: 0x5, + Data: []byte("foobar"), + }}, + EncryptionLevel: protocol.EncryptionForwardSecure, + } + retransmissions := []*packedPacket{getPacket(1337), getPacket(1338)} + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().DequeuePacketForRetransmission().Return(packet) + packer.EXPECT().PackRetransmission(packet).Return(retransmissions, nil) + sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { + Expect(packets).To(HaveLen(2)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(1337))) + Expect(packets[1].PacketNumber).To(Equal(protocol.PacketNumber(1338))) + }) + sess.sentPacketHandler = sph + sent, err := sess.maybeSendRetransmission() + Expect(err).NotTo(HaveOccurred()) + Expect(sent).To(BeTrue()) + Expect(mconn.written).To(HaveLen(2)) }) It("sends a probe packet", func() { - f := &wire.MaxDataFrame{ByteOffset: 1337} + packetToRetransmit := &ackhandler.Packet{ + PacketNumber: 0x42, + PacketType: protocol.PacketTypeHandshake, + } + retransmittedPacket := getPacket(123) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sph.EXPECT().TimeUntilSend() sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendTLP) sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().DequeueProbePacket().Return(&ackhandler.Packet{ - PacketNumber: 0x42, - Frames: []wire.Frame{f}, - }, nil) - sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{}) + sph.EXPECT().DequeueProbePacket().Return(packetToRetransmit, nil) + packer.EXPECT().PackRetransmission(packetToRetransmit).Return([]*packedPacket{retransmittedPacket}, nil) sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(0x42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[1]).To(Equal(f)) + Expect(packets[0].PacketNumber).To(Equal(protocol.PacketNumber(123))) }) sess.sentPacketHandler = sph - err := sess.sendPackets() - Expect(err).ToNot(HaveOccurred()) + Expect(sess.sendPackets()).To(Succeed()) }) It("doesn't send when the SentPacketHandler doesn't allow it", func() { @@ -842,420 +871,197 @@ var _ = Describe("Session", func() { err := sess.sendPackets() Expect(err).ToNot(HaveOccurred()) }) - }) - Context("packet pacing", func() { - var sph *mockackhandler.MockSentPacketHandler + Context("packet pacing", func() { + var sph *mockackhandler.MockSentPacketHandler - BeforeEach(func() { - sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() - sess.sentPacketHandler = sph - sess.packer.hasSentPacket = true - streamManager.EXPECT().CloseWithError(gomock.Any()) - }) - - It("sends multiple packets one by one immediately", func() { - sph.EXPECT().SentPacket(gomock.Any()).Times(2) - sph.EXPECT().ShouldSendNumPackets().Return(1).Times(2) - sph.EXPECT().TimeUntilSend().Return(time.Now()).Times(2) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Do(func() { - // make sure there's something to send - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - }).Times(2) // allow 2 packets... - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(2)) - Consistently(mconn.written).Should(HaveLen(2)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck - // we shouldn't send the ACK in the same run - It("doesn't send an ACK right after becoming congestion limited", func() { - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - sph.EXPECT().SentPacket(gomock.Any()) - sph.EXPECT().ShouldSendNumPackets().Return(1000) - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().SendMode().Return(ackhandler.SendAny) - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).Times(2) - rph.EXPECT().GetAckFrame() - sess.receivedPacketHandler = rph - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - Consistently(mconn.written).Should(HaveLen(1)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("paces packets", func() { - pacingDelay := scaleDuration(100 * time.Millisecond) - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - sph.EXPECT().SentPacket(gomock.Any()).Times(2) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(-time.Minute)) // send one packet immediately - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)) // send one - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().ShouldSendNumPackets().Times(2).Return(1) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Do(func() { // after sending the first packet - // make sure there's something to send - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 2}) - }).AnyTimes() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1)) - Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("sends multiple packets at once", func() { - sph.EXPECT().SentPacket(gomock.Any()).Times(3) - sph.EXPECT().ShouldSendNumPackets().Return(3) - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).Do(func() { - // make sure there's something to send - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - }).Times(3) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(3)) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("doesn't set a pacing timer when there is no data to send", func() { - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() // no packet will get sent - Consistently(mconn.written).ShouldNot(Receive()) - // queue a frame, and expect that it won't be sent - sess.packer.QueueControlFrame(&wire.MaxDataFrame{ByteOffset: 1}) - Consistently(mconn.written).ShouldNot(Receive()) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - }) - - Context("sending ACK only packets", func() { - It("doesn't do anything if there's no ACK to be sent", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sess.sentPacketHandler = sph - err := sess.maybeSendAckOnlyPacket() - Expect(err).ToNot(HaveOccurred()) - Expect(mconn.written).To(BeEmpty()) - }) - - It("sends ACK only packets", func() { - swf := &wire.StopWaitingFrame{LeastUnacked: 10} - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - sph.EXPECT().ShouldSendNumPackets().Return(1000) - sph.EXPECT().GetStopWaitingFrame(false).Return(swf) - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.Frames[1]).To(Equal(swf)) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sess.sentPacketHandler = sph - sess.packer.packetNumberGenerator.next = 0x1338 - sess.receivedPacketHandler.ReceivedPacket(1, time.Now(), true) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - // make sure that the go routine returns - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - - It("doesn't include a STOP_WAITING for an ACK-only packet for IETF QUIC", func() { - sess.version = versionIETFFrames - done := make(chan struct{}) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAck) - sph.EXPECT().ShouldSendNumPackets().Return(1000) - sph.EXPECT().TimeUntilSend() - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames).To(HaveLen(1)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sess.sentPacketHandler = sph - sess.packer.packetNumberGenerator.next = 0x1338 - sess.receivedPacketHandler.ReceivedPacket(1, time.Now(), true) - go func() { - defer GinkgoRecover() - sess.run() - close(done) - }() - sess.scheduleSending() - Eventually(mconn.written).Should(HaveLen(1)) - // make sure that the go routine returns - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(done).Should(BeClosed()) - }) - }) - - Context("retransmissions", func() { - var sph *mockackhandler.MockSentPacketHandler - BeforeEach(func() { - // a STOP_WAITING frame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet - sess.packer.packetNumberGenerator.next = 0x1337 + 10 - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sess.sentPacketHandler = sph - sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} - }) - - Context("for handshake packets", func() { - It("retransmits an unencrypted packet, and adds a STOP_WAITING frame (for gQUIC)", func() { - sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - swf := &wire.StopWaitingFrame{LeastUnacked: 0x1337} - sph.EXPECT().GetStopWaitingFrame(true).Return(swf) - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 42, - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionUnencrypted, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(p.Frames).To(Equal([]wire.Frame{swf, sf})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + BeforeEach(func() { + sph = mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() + sph.EXPECT().DequeuePacketForRetransmission().AnyTimes() + sess.sentPacketHandler = sph + streamManager.EXPECT().CloseWithError(gomock.Any()) }) - It("retransmits an unencrypted packet, and doesn't add a STOP_WAITING frame (for IETF QUIC)", func() { - sess.version = versionIETFFrames - sess.packer.version = versionIETFFrames - sess.packer.srcConnID = sess.destConnID - sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")} - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 1337, - Frames: []wire.Frame{sf}, - EncryptionLevel: protocol.EncryptionUnencrypted, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(1337)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - Expect(p.Frames).To(Equal([]wire.Frame{sf})) - Expect(p.SendTime).To(BeTemporally("~", time.Now(), 100*time.Millisecond)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + It("sends multiple packets one by one immediately", func() { + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().ShouldSendNumPackets().Return(1).Times(2) + sph.EXPECT().TimeUntilSend().Return(time.Now()).Times(2) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(2) // allow 2 packets... + packer.EXPECT().PackPacket().Return(getPacket(10), nil) + packer.EXPECT().PackPacket().Return(getPacket(11), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(2)) + Consistently(mconn.written).Should(HaveLen(2)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + // when becoming congestion limited, at some point the SendMode will change from SendAny to SendAck + // we shouldn't send the ACK in the same run + It("doesn't send an ACK right after becoming congestion limited", func() { + sph.EXPECT().SentPacket(gomock.Any()) + sph.EXPECT().ShouldSendNumPackets().Return(1000) + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().SendMode().Return(ackhandler.SendAny) + sph.EXPECT().SendMode().Return(ackhandler.SendAck) + packer.EXPECT().PackPacket().Return(getPacket(100), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(1)) + Consistently(mconn.written).Should(HaveLen(1)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + It("paces packets", func() { + pacingDelay := scaleDuration(100 * time.Millisecond) + sph.EXPECT().SentPacket(gomock.Any()).Times(2) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(-time.Minute)) // send one packet immediately + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(pacingDelay)) // send one + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().ShouldSendNumPackets().Times(2).Return(1) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + packer.EXPECT().PackPacket().Return(getPacket(100), nil) + packer.EXPECT().PackPacket().Return(getPacket(101), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(1)) + Consistently(mconn.written, pacingDelay/2).Should(HaveLen(1)) + Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + It("sends multiple packets at once", func() { + sph.EXPECT().SentPacket(gomock.Any()).Times(3) + sph.EXPECT().ShouldSendNumPackets().Return(3) + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).Times(3) + packer.EXPECT().PackPacket().Return(getPacket(1000), nil) + packer.EXPECT().PackPacket().Return(getPacket(1001), nil) + packer.EXPECT().PackPacket().Return(getPacket(1002), nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() + Eventually(mconn.written).Should(HaveLen(3)) + // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sess.Close() + Eventually(done).Should(BeClosed()) + }) + + It("doesn't set a pacing timer when there is no data to send", func() { + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + packer.EXPECT().PackPacket() + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + sess.run() + close(done) + }() + sess.scheduleSending() // no packet will get sent + Consistently(mconn.written).ShouldNot(Receive()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sess.Close() + Eventually(done).Should(BeClosed()) }) }) - Context("for packets after the handshake", func() { - It("sends a STREAM frame from a packet queued for retransmission, and adds a STOP_WAITING (for gQUIC)", func() { - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: []byte("foobar"), - } - swf := &wire.StopWaitingFrame{LeastUnacked: 10} - sph.EXPECT().GetStopWaitingFrame(true).Return(swf) - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 0x1337, - Frames: []wire.Frame{f}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(0x1337)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.Frames).To(HaveLen(2)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.StopWaitingFrame{})) - Expect(p.Frames[1]).To(Equal(f)) - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + Context("scheduling sending", func() { + It("sends when scheduleSending is called", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1) + sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph + packer.EXPECT().PackPacket().Return(getPacket(1), nil) + + go func() { + defer GinkgoRecover() + sess.run() + }() + Consistently(mconn.written).ShouldNot(Receive()) + sess.scheduleSending() + Eventually(mconn.written).Should(Receive()) + // make the go routine return + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sess.Close() + Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sends a STREAM frame from a packet queued for retransmission, and doesn't add a STOP_WAITING (for IETF QUIC)", func() { - sess.version = versionIETFFrames - sess.packer.version = versionIETFFrames - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: []byte("foobar"), - } - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 42, - Frames: []wire.Frame{f}, - EncryptionLevel: protocol.EncryptionForwardSecure, + It("sets the timer to the ack timer", func() { + packer.EXPECT().PackPacket().Return(getPacket(1234), nil) + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().Return(time.Now()) + sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) + sph.EXPECT().GetAlarmTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().Return(1) + sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { + Expect(p.PacketNumber).To(Equal(protocol.PacketNumber(1234))) }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(1)) - p := packets[0] - Expect(p.Frames).To(Equal([]wire.Frame{f})) - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(1)) + sess.sentPacketHandler = sph + rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) + // make the run loop wait + rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) + sess.receivedPacketHandler = rph + + go func() { + defer GinkgoRecover() + sess.run() + }() + Eventually(mconn.written).Should(Receive()) + // make sure the go routine returns + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) + sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + streamManager.EXPECT().CloseWithError(gomock.Any()) + sess.Close() + Eventually(sess.Context().Done()).Should(BeClosed()) }) - - It("sends multiple packets, if the retransmission is split", func() { - sess.version = versionIETFFrames - sess.packer.version = versionIETFFrames - f := &wire.StreamFrame{ - StreamID: 0x5, - Data: bytes.Repeat([]byte{'b'}, int(protocol.MaxPacketSizeIPv4)*3/2), - } - sph.EXPECT().DequeuePacketForRetransmission().Return(&ackhandler.Packet{ - PacketNumber: 42, - Frames: []wire.Frame{f}, - EncryptionLevel: protocol.EncryptionForwardSecure, - }) - sph.EXPECT().SentPacketsAsRetransmission(gomock.Any(), protocol.PacketNumber(42)).Do(func(packets []*ackhandler.Packet, _ protocol.PacketNumber) { - Expect(packets).To(HaveLen(2)) - for _, p := range packets { - Expect(p.Frames).To(HaveLen(1)) - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.StreamFrame{})) - Expect(p.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure)) - } - }) - sent, err := sess.maybeSendRetransmission() - Expect(err).NotTo(HaveOccurred()) - Expect(sent).To(BeTrue()) - Expect(mconn.written).To(HaveLen(2)) - }) - }) - }) - - Context("scheduling sending", func() { - BeforeEach(func() { - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends - sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure} - }) - - It("sends when scheduleSending is called", func() { - sess.packer.packetNumberGenerator.next = 10000 - sess.packer.QueueControlFrame(&wire.BlockedFrame{}) - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().TimeUntilSend().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(1) - sph.EXPECT().GetPacketNumberLen(gomock.Any()).Return(protocol.PacketNumberLen2).AnyTimes() - sph.EXPECT().SentPacket(gomock.Any()) - sess.sentPacketHandler = sph - - go func() { - defer GinkgoRecover() - sess.run() - }() - Consistently(mconn.written).ShouldNot(Receive()) - sess.scheduleSending() - Eventually(mconn.written).Should(Receive()) - // make the go routine return - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(sess.Context().Done()).Should(BeClosed()) - }) - - It("sets the timer to the ack timer", func() { - sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) - sph.EXPECT().TimeUntilSend().Return(time.Now()) - sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)) - sph.EXPECT().GetAlarmTimeout().AnyTimes() - sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() - sph.EXPECT().GetStopWaitingFrame(gomock.Any()) - sph.EXPECT().ShouldSendNumPackets().Return(1) - sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) { - Expect(p.Frames[0]).To(BeAssignableToTypeOf(&wire.AckFrame{})) - Expect(p.Frames[0].(*wire.AckFrame).LargestAcked()).To(Equal(protocol.PacketNumber(0x1337))) - }) - sess.sentPacketHandler = sph - rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) - rph.EXPECT().GetAckFrame().Return(&wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 1, Largest: 0x1337}}}) - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(10 * time.Millisecond)) - // make the run loop wait - rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1) - sess.receivedPacketHandler = rph - - go func() { - defer GinkgoRecover() - sess.run() - }() - Eventually(mconn.written).Should(Receive()) - // make sure the go routine returns - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) - streamManager.EXPECT().CloseWithError(gomock.Any()) - sess.Close() - Eventually(sess.Context().Done()).Should(BeClosed()) }) }) @@ -1263,6 +1069,7 @@ var _ = Describe("Session", func() { testErr := errors.New("crypto setup error") streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.handleErr = testErr go func() { defer GinkgoRecover() @@ -1272,7 +1079,7 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - Context("sending a Public Reset when receiving undecryptable packets during the handshake", func() { + Context("sending a PUBLIC_RESET when receiving undecryptable packets during the handshake", func() { // sends protocol.MaxUndecryptablePackets+1 undecrytable packets // this completely fills up the undecryptable packets queue and triggers the public reset timer sendUndecryptablePackets := func() { @@ -1294,9 +1101,10 @@ var _ = Describe("Session", func() { sess.unpacker = unpacker sess.cryptoStreamHandler = &mockCryptoSetup{} streamManager.EXPECT().CloseWithError(gomock.Any()).MaxTimes(1) + packer.EXPECT().PackPacket().AnyTimes() }) - It("doesn't immediately send a Public Reset after receiving too many undecryptable packets", func() { + It("doesn't immediately send a PUBLIC_RESET after receiving too many undecryptable packets", func() { go func() { defer GinkgoRecover() sess.run() @@ -1306,11 +1114,12 @@ var _ = Describe("Session", func() { Consistently(mconn.written).Should(HaveLen(0)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sets a deadline to send a Public Reset after receiving too many undecryptable packets", func() { + It("sets a deadline to send a PUBLIC_RESET after receiving too many undecryptable packets", func() { go func() { defer GinkgoRecover() sess.run() @@ -1319,6 +1128,7 @@ var _ = Describe("Session", func() { Eventually(func() time.Time { return sess.receivedTooManyUndecrytablePacketsTime }).Should(BeTemporally("~", time.Now(), 20*time.Millisecond)) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1334,11 +1144,12 @@ var _ = Describe("Session", func() { Expect(sess.undecryptablePackets[0].header.PacketNumber).To(Equal(protocol.PacketNumber(1))) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("sends a Public Reset after a timeout", func() { + It("sends a PUBLIC_RESET after a timeout", func() { sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.receivedTooManyUndecrytablePacketsTime).To(BeZero()) go func() { @@ -1356,7 +1167,7 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("doesn't send a Public Reset if decrypting them succeeded during the timeout", func() { + It("doesn't send a PUBLIC_RESET if decrypting them succeeded during the timeout", func() { go func() { defer GinkgoRecover() sess.run() @@ -1369,6 +1180,7 @@ var _ = Describe("Session", func() { Expect(sess.Context().Done()).ToNot(Receive()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1383,6 +1195,7 @@ var _ = Describe("Session", func() { Consistently(sess.undecryptablePackets).Should(BeEmpty()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1408,21 +1221,23 @@ var _ = Describe("Session", func() { // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) It("calls the onHandshakeComplete callback when the handshake completes", func() { + close(handshakeChan) + sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) go func() { defer GinkgoRecover() sess.run() }() - sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) - close(handshakeChan) Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1437,6 +1252,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1452,6 +1268,7 @@ var _ = Describe("Session", func() { }() streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.CloseWithError(0x1337, testErr)).To(Succeed()) Eventually(done).Should(BeClosed()) }) @@ -1472,13 +1289,14 @@ var _ = Describe("Session", func() { MaxPacketSize: 0x42, } streamManager.EXPECT().UpdateLimits(¶ms) + packer.EXPECT().SetOmitConnectionID() + packer.EXPECT().SetMaxPacketSize(protocol.ByteCount(0x42)) paramsChan <- params Eventually(func() *handshake.TransportParameters { return sess.peerParams }).Should(Equal(¶ms)) - Eventually(func() bool { return sess.packer.omitConnectionID }).Should(BeTrue()) - Eventually(func() protocol.ByteCount { return sess.packer.maxPacketSize }).Should(Equal(protocol.ByteCount(0x42))) // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) }) @@ -1496,20 +1314,23 @@ var _ = Describe("Session", func() { sess.handshakeComplete = true sess.config.KeepAlive = true sess.lastNetworkActivityTime = time.Now().Add(-remoteIdleTimeout / 2) - sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends + sent := make(chan struct{}) + packer.EXPECT().QueueControlFrame(&wire.PingFrame{}) + packer.EXPECT().PackPacket().Do(func() (*packedPacket, error) { + close(sent) + return nil, nil + }) done := make(chan struct{}) go func() { defer GinkgoRecover() sess.run() close(done) }() - var data []byte - Eventually(mconn.written).Should(Receive(&data)) - // -12 because of the crypto tag. This should be 7 (the frame id for a ping frame). - Expect(data[len(data)-12-1 : len(data)-12]).To(Equal([]byte{0x07})) + Eventually(sent).Should(BeClosed()) // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1528,6 +1349,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1546,6 +1368,7 @@ var _ = Describe("Session", func() { // make the go routine return sessionRunner.EXPECT().removeConnectionID(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.Close() Eventually(done).Should(BeClosed()) }) @@ -1561,6 +1384,10 @@ var _ = Describe("Session", func() { sess.handshakeComplete = true sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) done := make(chan struct{}) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + return &packedPacket{}, nil + }) go func() { defer GinkgoRecover() err := sess.run() @@ -1568,12 +1395,15 @@ var _ = Describe("Session", func() { close(done) }() Eventually(done).Should(BeClosed()) - Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) }) It("times out due to non-completed handshake", func() { sessionRunner.EXPECT().removeConnectionID(gomock.Any()) sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.HandshakeTimeout)) + return &packedPacket{}, nil + }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -1582,13 +1412,16 @@ var _ = Describe("Session", func() { close(done) }() Eventually(done).Should(BeClosed()) - Expect(mconn.written).To(Receive(ContainSubstring("Crypto handshake did not complete in time."))) }) It("does not use the idle timeout before the handshake complete", func() { sess.config.IdleTimeout = 9999 * time.Second defer sess.Close() sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.PeerGoingAway)) + return &packedPacket{}, nil + }) // the handshake timeout is irrelevant here, since it depends on the time the session was created, // and not on the last network activity go func() { @@ -1605,6 +1438,10 @@ var _ = Describe("Session", func() { It("closes the session due to the idle timeout after handshake", func() { sessionRunner.EXPECT().onHandshakeComplete(sess) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).DoAndReturn(func(f *wire.ConnectionCloseFrame) (*packedPacket, error) { + Expect(f.ErrorCode).To(Equal(qerr.NetworkIdleTimeout)) + return &packedPacket{}, nil + }) sess.config.IdleTimeout = 0 close(handshakeChan) done := make(chan struct{}) @@ -1615,7 +1452,6 @@ var _ = Describe("Session", func() { close(done) }() Eventually(done).Should(BeClosed()) - Expect(mconn.written).To(Receive(ContainSubstring("No recent network activity."))) }) }) @@ -1718,6 +1554,7 @@ var _ = Describe("Client Session", func() { var ( sess *session sessionRunner *MockSessionRunner + packer *MockPacker mconn *mockConnection handshakeChan chan<- struct{} @@ -1762,6 +1599,8 @@ var _ = Describe("Client Session", func() { ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) + packer = NewMockPacker(mockCtrl) + sess.packer = packer }) AfterEach(func() { @@ -1769,24 +1608,31 @@ var _ = Describe("Client Session", func() { }) It("sends a forward-secure packet when the handshake completes", func() { - sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) - sess.packer.hasSentPacket = true + done := make(chan struct{}) + gomock.InOrder( + sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()), + packer.EXPECT().QueueControlFrame(&wire.PingFrame{}), + packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { + close(done) + return &packedPacket{header: &wire.Header{}, raw: *getPacketBuffer()}, nil + }), + packer.EXPECT().PackPacket().AnyTimes(), + ) + close(handshakeChan) go func() { defer GinkgoRecover() sess.run() }() - close(handshakeChan) - Eventually(mconn.written).Should(Receive()) + Eventually(done).Should(BeClosed()) //make sure the go routine returns sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) }) It("changes the connection ID when receiving the first packet from the server", func() { sess.version = protocol.VersionTLS - sess.packer.version = protocol.VersionTLS - sess.packer.srcConnID = sess.destConnID unpacker := NewMockUnpacker(mockCtrl) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) sess.unpacker = unpacker @@ -1794,25 +1640,20 @@ var _ = Describe("Client Session", func() { defer GinkgoRecover() sess.run() }() + newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} + packer.EXPECT().ChangeDestConnectionID(newConnID) err := sess.handlePacketImpl(&receivedPacket{ header: &wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, - SrcConnectionID: protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7}, + SrcConnectionID: newConnID, DestConnectionID: sess.srcConnID, }, data: []byte{0}, }) Expect(err).ToNot(HaveOccurred()) - // the session should have changed the dest connection ID now - sess.packer.hasSentPacket = true - sess.queueControlFrame(&wire.PingFrame{}) - var packet []byte - Eventually(mconn.written).Should(Receive(&packet)) - hdr, err := wire.ParseInvariantHeader(bytes.NewReader(packet), 0) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7})) // make sure the go routine returns + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1841,6 +1682,7 @@ var _ = Describe("Client Session", func() { Expect(err).ToNot(HaveOccurred()) Expect(cryptoSetup.divNonce).To(Equal(hdr.DiversificationNonce)) // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().removeConnectionID(gomock.Any()) Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed())