diff --git a/session.go b/session.go index 0c2f3beb0..bb71ad65d 100644 --- a/session.go +++ b/session.go @@ -600,6 +600,13 @@ func (s *session) handleHandshakeEvent(completed bool) { } func (s *session) handlePacketImpl(p *receivedPacket) error { + hdr := p.header + // The server can change the source connection ID with the first Handshake packet. + // After this, all packets with a different source connection have to be ignored. + if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID) + return nil + } if s.perspective == protocol.PerspectiveClient { if divNonce := p.header.DiversificationNonce; len(divNonce) > 0 { if err := s.cryptoStreamHandler.(divNonceSetter).SetDiversificationNonce(divNonce); err != nil { @@ -613,9 +620,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { p.rcvTime = time.Now() } - hdr := p.header - data := p.data - // Calculate packet number hdr.PacketNumber = protocol.InferPacketNumber( hdr.PacketNumberLen, @@ -624,12 +628,12 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { s.version, ) - packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) + packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data) if s.logger.Debug() { if err != nil { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID) } else { - s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(data)+len(hdr.Raw), hdr.DestConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s, %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID, packet.encryptionLevel) } hdr.Log(s.logger) } @@ -638,6 +642,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) error { return err } + // The server can change the source connection ID with the first Handshake packet. if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && !hdr.SrcConnectionID.Equal(s.destConnID) { s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", hdr.SrcConnectionID) s.destConnID = hdr.SrcConnectionID diff --git a/session_test.go b/session_test.go index 80240f0ef..58fff6ceb 100644 --- a/session_test.go +++ b/session_test.go @@ -659,6 +659,29 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) }) + It("ignores packets with a different source connection ID", func() { + // Send one packet, which might change the connection ID. + // only EXPECT one call to the unpacker + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil) + err := sess.handlePacketImpl(&receivedPacket{ + header: &wire.Header{ + IsLongHeader: true, + DestConnectionID: sess.destConnID, + SrcConnectionID: sess.srcConnID, + }, + }) + Expect(err).ToNot(HaveOccurred()) + // The next packet has to be ignored, since the source connection ID doesn't match. + err = sess.handlePacketImpl(&receivedPacket{ + header: &wire.Header{ + IsLongHeader: true, + DestConnectionID: sess.destConnID, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + }, + }) + Expect(err).ToNot(HaveOccurred()) + }) + Context("updating the remote address", func() { It("doesn't support connection migration", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{}, nil)