diff --git a/session.go b/session.go index 81147aab..eef65d2e 100644 --- a/session.go +++ b/session.go @@ -583,20 +583,41 @@ runLoop: case <-s.sendingScheduled: // We do all the interesting stuff after the switch statement, so // nothing to see here. - case p := <-s.receivedPackets: - s.sentPacketHandler.ReceivedBytes(p.Size()) - // Only reset the timers if this packet was actually processed. - // This avoids modifying any state when handling undecryptable packets, - // which could be injected by an attacker. - if wasProcessed := s.handlePacketImpl(p); !wasProcessed { - continue - } + case firstPacket := <-s.receivedPackets: + s.sentPacketHandler.ReceivedBytes(firstPacket.Size()) + wasProcessed := s.handlePacketImpl(firstPacket) // Don't set timers and send packets if the packet made us close the session. select { case closeErr = <-s.closeChan: break runLoop default: } + // Now process all packets in the receivedPackets channel. + // Limit the number of packets to the length of the receivedPackets channel, + // so we eventually get a chance to send out an ACK when receiving a lot of packets. + numPackets := len(s.receivedPackets) + receiveLoop: + for i := 0; i < numPackets; i++ { + select { + case p := <-s.receivedPackets: + if processed := s.handlePacketImpl(p); processed { + wasProcessed = true + } + select { + case closeErr = <-s.closeChan: + break runLoop + default: + } + default: + break receiveLoop + } + } + // Only reset the timers if this packet was actually processed. + // This avoids modifying any state when handling undecryptable packets, + // which could be injected by an attacker. + if !wasProcessed { + continue + } case <-s.handshakeCompleteChan: s.handleHandshakeComplete() } diff --git a/session_test.go b/session_test.go index 2b153837..b2a30487 100644 --- a/session_test.go +++ b/session_test.go @@ -671,8 +671,9 @@ var _ = Describe("Session", func() { buf := &bytes.Buffer{} Expect(extHdr.Write(buf, sess.version)).To(Succeed()) return &receivedPacket{ - data: append(buf.Bytes(), data...), - buffer: getPacketBuffer(), + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + rcvTime: time.Now(), } } @@ -857,6 +858,51 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) + It("processes multiple received packets before sending one", func() { + sess.sessionCreationTime = time.Now() + var pn protocol.PacketNumber + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { + pn++ + return &unpackedPacket{ + data: []byte{0}, // PADDING frame + encryptionLevel: protocol.Encryption1RTT, + packetNumber: pn, + hdr: &wire.ExtendedHeader{Header: *hdr}, + }, nil + }).Times(3) + tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tracer.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(hdr *wire.ExtendedHeader, _ protocol.ByteCount, _ []logging.Frame) { + }).Times(3) + packer.EXPECT().PackCoalescedPacket() // only expect a single call + + for i := 0; i < 3; i++ { + sess.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumber: 0x1337, + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte("foobar"))) + } + + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + sess.run() + }() + Consistently(sess.Context().Done()).ShouldNot(BeClosed()) + + // make the go routine return + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + expectReplaceWithClosed() + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + mconn.EXPECT().Write(gomock.Any()) + fmt.Println("close") + sess.closeLocal(errors.New("close")) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("closes the session when unpacking fails because the reserved bits were incorrect", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, wire.ErrInvalidReservedBits) streamManager.EXPECT().CloseWithError(gomock.Any())