diff --git a/session.go b/session.go index 63fd72cf..eb104299 100644 --- a/session.go +++ b/session.go @@ -596,24 +596,26 @@ runLoop: break runLoop default: } - // Now process all packets in the receivedPackets channel. - // Limit the number of packets to the length of the receivedPackets channel, - // so we eventually get a chance to send out an ACK when receiving a lot of packets. - numPackets := len(s.receivedPackets) - receiveLoop: - for i := 0; i < numPackets; i++ { - select { - case p := <-s.receivedPackets: - if processed := s.handlePacketImpl(p); processed { - wasProcessed = true - } + if s.handshakeComplete { + // Now process all packets in the receivedPackets channel. + // Limit the number of packets to the length of the receivedPackets channel, + // so we eventually get a chance to send out an ACK when receiving a lot of packets. + numPackets := len(s.receivedPackets) + receiveLoop: + for i := 0; i < numPackets; i++ { select { - case closeErr = <-s.closeChan: - break runLoop + case p := <-s.receivedPackets: + if processed := s.handlePacketImpl(p); processed { + wasProcessed = true + } + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } default: + break receiveLoop } - default: - break receiveLoop } } // Only reset the timers if this packet was actually processed. diff --git a/session_test.go b/session_test.go index 859d276e..5e8cb97b 100644 --- a/session_test.go +++ b/session_test.go @@ -902,6 +902,51 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) + It("doesn't processes multiple received packets before sending one before handshake completion", func() { + sess.handshakeComplete = false + sess.sessionCreationTime = time.Now() + var pn protocol.PacketNumber + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + pn++ + return &unpackedPacket{ + data: []byte{0}, // PADDING frame + encryptionLevel: protocol.Encryption1RTT, + packetNumber: pn, + hdr: &wire.ExtendedHeader{Header: *hdr}, + }, nil + }).Times(3) + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ []logging.Frame) { + }).Times(3) + packer.EXPECT().PackCoalescedPacket().Times(3) // only expect a single call + + for i := 0; i < 3; i++ { + sess.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte("foobar"))) + } + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + sess.run() + }() + Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + sess.closeLocal(errors.New("close")) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("closes the session when unpacking fails because the reserved bits were incorrect", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) streamManager.EXPECT().CloseWithError(gomock.Any())