diff --git a/session.go b/session.go index 67d44654..68eb8fab 100644 --- a/session.go +++ b/session.go @@ -61,6 +61,15 @@ type receivedPacket struct { buffer *packetBuffer } +func (p *receivedPacket) Clone() *receivedPacket { + return &receivedPacket{ + remoteAddr: p.remoteAddr, + rcvTime: p.rcvTime, + data: p.data, + buffer: p.buffer, + } +} + type closeError struct { err error remote bool @@ -487,11 +496,18 @@ func (s *session) handleHandshakeComplete() { } } -func (s *session) handlePacketImpl(p *receivedPacket) bool { +func (s *session) handlePacketImpl(rp *receivedPacket) bool { var counter uint8 var lastConnID protocol.ConnectionID var processed bool - for len(p.data) > 0 { + data := rp.data + p := rp + for len(data) > 0 { + if counter > 0 { + p = p.Clone() + p.data = data + } + hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnID.Len()) if err != nil { s.logger.Debugf("error parsing packet: %s", err) @@ -514,11 +530,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool { s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) } p.data = packetData - pr := s.handleSinglePacket(p, hdr) - if pr { - processed = pr + if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed { + processed = true } - p.data = rest + data = rest } p.buffer.MaybeRelease() return processed @@ -744,6 +759,7 @@ func (s *session) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.E if err != nil { return err } + s.logger.Debugf("Handled crypto frame at level %s. encLevelChanged: %t", encLevel, encLevelChanged) if encLevelChanged { s.tryDecryptingQueuedPackets() } diff --git a/session_test.go b/session_test.go index af5e6be2..9e17a1e9 100644 --- a/session_test.go +++ b/session_test.go @@ -632,6 +632,25 @@ var _ = Describe("Session", func() { Expect(sess.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) }) + It("queues undecryptable packets", func() { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: sess.destConnID, + SrcConnectionID: sess.srcConnID, + Length: 1, + Version: sess.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, + } + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable) + packet := getPacket(hdr, nil) + Expect(sess.handlePacketImpl(packet)).To(BeFalse()) + Expect(sess.undecryptablePackets).To(Equal([]*receivedPacket{packet})) + }) + Context("updating the remote address", func() { It("doesn't support connection migration", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ @@ -705,6 +724,26 @@ var _ = Describe("Session", func() { Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) }) + It("works with undecryptable packets", func() { + hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) + hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123) + gomock.InOrder( + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrOpenerNotYetAvailable), + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(hdrLen2 + 123 - 3)) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }), + ) + packet1.data = append(packet1.data, packet2.data...) + Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + + Expect(sess.undecryptablePackets).To(HaveLen(1)) + Expect(sess.undecryptablePackets[0].data).To(HaveLen(hdrLen1 + 456 - 3)) + }) + It("ignores coalesced packet parts if the destination connection IDs don't match", func() { wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} Expect(sess.srcConnID).ToNot(Equal(wrongConnID))