diff --git a/server.go b/server.go index c54e94b4..788a3fb4 100644 --- a/server.go +++ b/server.go @@ -432,9 +432,6 @@ func (s *baseServer) createNewSession( qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID) } } - if qlogger != nil { - qlogger.StartedConnection(s.conn.LocalAddr(), remoteAddr, version, srcConnID, destConnID) - } sess := s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, s.sessionHandler, diff --git a/session.go b/session.go index 69d4b7e8..b452f8c3 100644 --- a/session.go +++ b/session.go @@ -866,15 +866,37 @@ func (s *session) handleUnpackedPacket( return qerr.NewError(qerr.ProtocolViolation, "empty packet") } - // The server can change the source connection ID with the first Handshake packet. - if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { - cid := packet.hdr.SrcConnectionID - s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) - s.handshakeDestConnID = cid - s.connIDManager.ChangeInitialConnID(cid) + if !s.receivedFirstPacket { + s.receivedFirstPacket = true + // The server can change the source connection ID with the first Handshake packet. + if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + cid := packet.hdr.SrcConnectionID + s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid) + s.handshakeDestConnID = cid + s.connIDManager.ChangeInitialConnID(cid) + } + // We create the session as soon as we receive the first packet from the client. + // We do that before authenticating the packet. + // That means that if the source connection ID was corrupted, + // we might have create a session with an incorrect source connection ID. + // Once we authenticate the first packet, we need to update it. + if s.perspective == protocol.PerspectiveServer { + if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) { + s.handshakeDestConnID = packet.hdr.SrcConnectionID + s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID) + } + if s.qlogger != nil { + s.qlogger.StartedConnection( + s.conn.LocalAddr(), + s.conn.RemoteAddr(), + s.version, + packet.hdr.SrcConnectionID, + packet.hdr.DestConnectionID, + ) + } + } } - s.receivedFirstPacket = true s.lastPacketReceivedTime = rcvTime s.firstAckElicitingPacketAfterIdleSentTime = time.Time{} s.keepAlivePingSent = false diff --git a/session_test.go b/session_test.go index 448f7119..b5383606 100644 --- a/session_test.go +++ b/session_test.go @@ -51,6 +51,8 @@ var _ = Describe("Session", func() { cryptoSetup *mocks.MockCryptoSetup qlogger *mocks.MockTracer ) + remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331} srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} @@ -81,8 +83,8 @@ var _ = Describe("Session", func() { sessionRunner = NewMockSessionRunner(mockCtrl) mconn = NewMockConnection(mockCtrl) - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2) - mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{}) + mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() + mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() tokenGenerator, err := handshake.NewTokenGenerator() Expect(err).ToNot(HaveOccurred()) qlogger = mocks.NewMockTracer(mockCtrl) @@ -526,16 +528,22 @@ var _ = Describe("Session", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes() buf := &bytes.Buffer{} - Expect((&wire.ExtendedHeader{ + hdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen2, - }).Write(buf, sess.version)).To(Succeed()) + } + Expect(hdr.Write(buf, sess.version)).To(Succeed()) unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) { buf := &bytes.Buffer{} Expect((&wire.ConnectionCloseFrame{ErrorCode: qerr.StreamLimitError}).Write(buf, sess.version)).To(Succeed()) - return &unpackedPacket{data: buf.Bytes(), encryptionLevel: protocol.Encryption1RTT}, nil + return &unpackedPacket{ + hdr: hdr, + data: buf.Bytes(), + encryptionLevel: protocol.Encryption1RTT, + }, nil }) gomock.InOrder( + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), qlogger.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()), qlogger.EXPECT().Export(), ) @@ -595,6 +603,7 @@ var _ = Describe("Session", func() { rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false) sess.receivedPacketHandler = rph packet.rcvTime = rcvTime + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), nil) Expect(sess.handlePacketImpl(packet)).To(BeTrue()) }) @@ -619,6 +628,7 @@ var _ = Describe("Session", func() { rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.Encryption1RTT, rcvTime, true) sess.receivedPacketHandler = rph packet.rcvTime = rcvTime + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []wire.Frame{&wire.PingFrame{}}) Expect(sess.handlePacketImpl(packet)).To(BeTrue()) }) @@ -765,6 +775,7 @@ var _ = Describe("Session", func() { data: []byte{0}, // one PADDING frame }, nil) p1 := getPacket(hdr1, nil) + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any()) Expect(sess.handlePacketImpl(p1)).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. @@ -806,12 +817,16 @@ var _ = Describe("Session", func() { PacketNumberLen: protocol.PacketNumberLen1, }, nil) packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) Expect(sess.handlePacketImpl(packet)).To(BeTrue()) }) }) Context("coalesced packets", func() { + BeforeEach(func() { + qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) + }) getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) { hdr := &wire.ExtendedHeader{ Header: wire.Header{ @@ -838,6 +853,7 @@ var _ = Describe("Session", func() { return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, data: []byte{0}, + hdr: &wire.ExtendedHeader{}, }, nil }) qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any()) @@ -851,6 +867,7 @@ var _ = Describe("Session", func() { return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, data: []byte{0}, + hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}}, }, nil }) hdrLen2, packet2 := getPacketWithLength(srcConnID, 123) @@ -859,6 +876,7 @@ var _ = Describe("Session", func() { return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, data: []byte{0}, + hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}}, }, nil }) gomock.InOrder( @@ -880,6 +898,7 @@ var _ = Describe("Session", func() { return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, data: []byte{0}, + hdr: &wire.ExtendedHeader{}, }, nil }), ) @@ -903,6 +922,7 @@ var _ = Describe("Session", func() { return &unpackedPacket{ encryptionLevel: protocol.EncryptionHandshake, data: []byte{0}, + hdr: &wire.ExtendedHeader{}, }, nil }) _, packet2 := getPacketWithLength(wrongConnID, 123) @@ -1355,7 +1375,6 @@ var _ = Describe("Session", func() { }() handshakeCtx := sess.HandshakeComplete() Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token close(finishHandshake) Eventually(handshakeCtx.Done()).Should(BeClosed()) Eventually(sphNotified).Should(BeClosed()) @@ -1387,7 +1406,6 @@ var _ = Describe("Session", func() { handshakeCtx := sess.HandshakeComplete() Consistently(handshakeCtx.Done()).ShouldNot(BeClosed()) - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token close(finishHandshake) var frames []ackhandler.Frame Eventually(func() []ackhandler.Frame { @@ -1456,7 +1474,6 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake() cryptoSetup.EXPECT().DropHandshakeKeys() cryptoSetup.EXPECT().GetSessionTicket() - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token mconn.EXPECT().Write(gomock.Any()) close(sess.handshakeCompleteChan) sess.run() @@ -1713,7 +1730,6 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1) - mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) close(sess.handshakeCompleteChan) err := sess.run() nerr, ok := err.(net.Error) @@ -1809,15 +1825,11 @@ var _ = Describe("Session", func() { }) It("returns the local address", func() { - addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} - mconn.EXPECT().LocalAddr().Return(addr) - Expect(sess.LocalAddr()).To(Equal(addr)) + Expect(sess.LocalAddr()).To(Equal(localAddr)) }) It("returns the remote address", func() { - addr := &net.UDPAddr{IP: net.IPv4(1, 2, 7, 1), Port: 7331} - mconn.EXPECT().RemoteAddr().Return(addr) - Expect(sess.RemoteAddr()).To(Equal(addr)) + Expect(sess.RemoteAddr()).To(Equal(remoteAddr)) }) })