diff --git a/session.go b/session.go index 077b57299..4ad24df72 100644 --- a/session.go +++ b/session.go @@ -104,7 +104,11 @@ var errCloseForRecreating = errors.New("closing session in order to recreate it" // A Session is a QUIC session type session struct { - origDestConnID protocol.ConnectionID // if the server sends a Retry, this is the connection ID we used initially + // Destination connection ID used during the handshake. + // Used to check source connection ID on incoming packets. + handshakeDestConnID protocol.ConnectionID + // if the server sends a Retry, this is the connection ID we used initially + origDestConnID protocol.ConnectionID srcConnIDLen int perspective protocol.Perspective @@ -201,6 +205,7 @@ var newSession = func( s := &session{ conn: conn, config: conf, + handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), tokenGenerator: tokenGenerator, perspective: protocol.PerspectiveServer, @@ -302,6 +307,7 @@ var newClientSession = func( s := &session{ conn: conn, config: conf, + handshakeDestConnID: destConnID, srcConnIDLen: srcConnID.Len(), perspective: protocol.PerspectiveClient, handshakeCompleteChan: make(chan struct{}), @@ -663,9 +669,8 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - destConnID := s.connIDManager.Get() - if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(destConnID) { - s.logger.Debugf("Dropping %s packet with unexpected source connection ID: %s (expected %s)", hdr.PacketType(), hdr.SrcConnectionID, destConnID) + if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + s.logger.Debugf("Dropping %s packet with unexpected source connection ID: %s (expected %s)", hdr.PacketType(), hdr.SrcConnectionID, s.handshakeDestConnID) return false } // drop 0-RTT packets @@ -715,12 +720,11 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R return false } (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) - destConnID := s.connIDManager.Get() - if !hdr.OrigDestConnectionID.Equal(destConnID) { - s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, destConnID) + if !hdr.OrigDestConnectionID.Equal(s.handshakeDestConnID) { + s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.handshakeDestConnID) return false } - if hdr.SrcConnectionID.Equal(destConnID) { + if hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.") return false } @@ -732,13 +736,14 @@ func (s *session) handleRetryPacket(hdr *wire.Header) bool /* was this a valid R } s.logger.Debugf("<- Received Retry") s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) - s.origDestConnID = destConnID + s.origDestConnID = s.handshakeDestConnID newDestConnID := hdr.SrcConnectionID s.receivedRetry = true if err := s.sentPacketHandler.ResetForRetry(); err != nil { s.closeLocal(err) return false } + s.handshakeDestConnID = newDestConnID s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) s.packer.SetToken(hdr.Token) s.connIDManager.ChangeInitialConnID(newDestConnID) @@ -752,9 +757,11 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time } // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.connIDManager.Get()) { - s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", packet.hdr.SrcConnectionID) - s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) + if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + cid := packet.hdr.SrcConnectionID + s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) + s.handshakeDestConnID = cid + s.connIDManager.ChangeInitialConnID(cid) } s.receivedFirstPacket = true diff --git a/session_test.go b/session_test.go index 3c762b971..a12703d32 100644 --- a/session_test.go +++ b/session_test.go @@ -1689,6 +1689,32 @@ var _ = Describe("Client Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) + It("continues accepting Long Header packets after using a new connection ID", func() { + unpacker := NewMockUnpacker(mockCtrl) + sess.unpacker = unpacker + sessionRunner.EXPECT().AddResetToken(gomock.Any(), gomock.Any()) + sess.handleNewConnectionIDFrame(&wire.NewConnectionIDFrame{ + SequenceNumber: 1, + ConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, + }) + Expect(sess.connIDManager.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) + // now receive a packet with the original source connection ID + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(hdr *wire.Header, _ time.Time, _ []byte) (*unpackedPacket, error) { + return &unpackedPacket{ + hdr: &wire.ExtendedHeader{Header: *hdr}, + data: []byte{0}, + encryptionLevel: protocol.EncryptionHandshake, + }, nil + }) + hdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: srcConnID, + SrcConnectionID: destConnID, + } + Expect(sess.handleSinglePacket(&receivedPacket{buffer: getPacketBuffer()}, hdr)).To(BeTrue()) + }) + Context("handling tokens", func() { var mockTokenStore *MockTokenStore