Merge pull request #2477 from lucas-clemente/fix-corrupted-scid

update the source connection ID if the session was created from an invalid Initial
This commit is contained in:
Marten Seemann
2020-04-09 00:45:15 +07:00
committed by GitHub
3 changed files with 56 additions and 25 deletions

View File

@@ -432,9 +432,6 @@ func (s *baseServer) createNewSession(
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID)
}
}
if qlogger != nil {
qlogger.StartedConnection(s.conn.LocalAddr(), remoteAddr, version, srcConnID, destConnID)
}
sess := s.newSession(
&conn{pconn: s.conn, currentAddr: remoteAddr},
s.sessionHandler,

View File

@@ -866,15 +866,37 @@ func (s *session) handleUnpackedPacket(
return qerr.NewError(qerr.ProtocolViolation, "empty packet")
}
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
cid := packet.hdr.SrcConnectionID
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
s.handshakeDestConnID = cid
s.connIDManager.ChangeInitialConnID(cid)
if !s.receivedFirstPacket {
s.receivedFirstPacket = true
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
cid := packet.hdr.SrcConnectionID
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
s.handshakeDestConnID = cid
s.connIDManager.ChangeInitialConnID(cid)
}
// We create the session as soon as we receive the first packet from the client.
// We do that before authenticating the packet.
// That means that if the source connection ID was corrupted,
// we might have create a session with an incorrect source connection ID.
// Once we authenticate the first packet, we need to update it.
if s.perspective == protocol.PerspectiveServer {
if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
s.handshakeDestConnID = packet.hdr.SrcConnectionID
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
}
if s.qlogger != nil {
s.qlogger.StartedConnection(
s.conn.LocalAddr(),
s.conn.RemoteAddr(),
s.version,
packet.hdr.SrcConnectionID,
packet.hdr.DestConnectionID,
)
}
}
}
s.receivedFirstPacket = true
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false

View File

@@ -51,6 +51,8 @@ var _ = Describe("Session", func() {
cryptoSetup *mocks.MockCryptoSetup
qlogger *mocks.MockTracer
)
remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331}
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
@@ -81,8 +83,8 @@ var _ = Describe("Session", func() {
sessionRunner = NewMockSessionRunner(mockCtrl)
mconn = NewMockConnection(mockCtrl)
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2)
mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{})
mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
tokenGenerator, err := handshake.NewTokenGenerator()
Expect(err).ToNot(HaveOccurred())
qlogger = mocks.NewMockTracer(mockCtrl)
@@ -526,16 +528,22 @@ var _ = Describe("Session", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen2,
}).Write(buf, sess.version)).To(Succeed())
}
Expect(hdr.Write(buf, sess.version)).To(Succeed())
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) {
buf := &bytes.Buffer{}
Expect((&wire.ConnectionCloseFrame{ErrorCode: qerr.StreamLimitError}).Write(buf, sess.version)).To(Succeed())
return &unpackedPacket{data: buf.Bytes(), encryptionLevel: protocol.Encryption1RTT}, nil
return &unpackedPacket{
hdr: hdr,
data: buf.Bytes(),
encryptionLevel: protocol.Encryption1RTT,
}, nil
})
gomock.InOrder(
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
qlogger.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()),
qlogger.EXPECT().Export(),
)
@@ -595,6 +603,7 @@ var _ = Describe("Session", func() {
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false)
sess.receivedPacketHandler = rph
packet.rcvTime = rcvTime
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), nil)
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
})
@@ -619,6 +628,7 @@ var _ = Describe("Session", func() {
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.Encryption1RTT, rcvTime, true)
sess.receivedPacketHandler = rph
packet.rcvTime = rcvTime
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []wire.Frame{&wire.PingFrame{}})
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
})
@@ -765,6 +775,7 @@ var _ = Describe("Session", func() {
data: []byte{0}, // one PADDING frame
}, nil)
p1 := getPacket(hdr1, nil)
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any())
Expect(sess.handlePacketImpl(p1)).To(BeTrue())
// The next packet has to be ignored, since the source connection ID doesn't match.
@@ -806,12 +817,16 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}, nil)
packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
})
})
Context("coalesced packets", func() {
BeforeEach(func() {
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
})
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) {
hdr := &wire.ExtendedHeader{
Header: wire.Header{
@@ -838,6 +853,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{},
}, nil
})
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
@@ -851,6 +867,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}},
}, nil
})
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
@@ -859,6 +876,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}},
}, nil
})
gomock.InOrder(
@@ -880,6 +898,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{},
}, nil
}),
)
@@ -903,6 +922,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{},
}, nil
})
_, packet2 := getPacketWithLength(wrongConnID, 123)
@@ -1355,7 +1375,6 @@ var _ = Describe("Session", func() {
}()
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
close(finishHandshake)
Eventually(handshakeCtx.Done()).Should(BeClosed())
Eventually(sphNotified).Should(BeClosed())
@@ -1387,7 +1406,6 @@ var _ = Describe("Session", func() {
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
close(finishHandshake)
var frames []ackhandler.Frame
Eventually(func() []ackhandler.Frame {
@@ -1456,7 +1474,6 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
mconn.EXPECT().Write(gomock.Any())
close(sess.handshakeCompleteChan)
sess.run()
@@ -1713,7 +1730,6 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
close(sess.handshakeCompleteChan)
err := sess.run()
nerr, ok := err.(net.Error)
@@ -1809,15 +1825,11 @@ var _ = Describe("Session", func() {
})
It("returns the local address", func() {
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
mconn.EXPECT().LocalAddr().Return(addr)
Expect(sess.LocalAddr()).To(Equal(addr))
Expect(sess.LocalAddr()).To(Equal(localAddr))
})
It("returns the remote address", func() {
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 7, 1), Port: 7331}
mconn.EXPECT().RemoteAddr().Return(addr)
Expect(sess.RemoteAddr()).To(Equal(addr))
Expect(sess.RemoteAddr()).To(Equal(remoteAddr))
})
})