forked from quic-go/quic-go
Merge pull request #2477 from lucas-clemente/fix-corrupted-scid
update the source connection ID if the session was created from an invalid Initial
This commit is contained in:
@@ -432,9 +432,6 @@ func (s *baseServer) createNewSession(
|
||||
qlogger = qlog.NewTracer(w, protocol.PerspectiveServer, connID)
|
||||
}
|
||||
}
|
||||
if qlogger != nil {
|
||||
qlogger.StartedConnection(s.conn.LocalAddr(), remoteAddr, version, srcConnID, destConnID)
|
||||
}
|
||||
sess := s.newSession(
|
||||
&conn{pconn: s.conn, currentAddr: remoteAddr},
|
||||
s.sessionHandler,
|
||||
|
||||
36
session.go
36
session.go
@@ -866,15 +866,37 @@ func (s *session) handleUnpackedPacket(
|
||||
return qerr.NewError(qerr.ProtocolViolation, "empty packet")
|
||||
}
|
||||
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
if s.perspective == protocol.PerspectiveClient && !s.receivedFirstPacket && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
|
||||
cid := packet.hdr.SrcConnectionID
|
||||
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
|
||||
s.handshakeDestConnID = cid
|
||||
s.connIDManager.ChangeInitialConnID(cid)
|
||||
if !s.receivedFirstPacket {
|
||||
s.receivedFirstPacket = true
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
if s.perspective == protocol.PerspectiveClient && packet.hdr.IsLongHeader && !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
|
||||
cid := packet.hdr.SrcConnectionID
|
||||
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
|
||||
s.handshakeDestConnID = cid
|
||||
s.connIDManager.ChangeInitialConnID(cid)
|
||||
}
|
||||
// We create the session as soon as we receive the first packet from the client.
|
||||
// We do that before authenticating the packet.
|
||||
// That means that if the source connection ID was corrupted,
|
||||
// we might have create a session with an incorrect source connection ID.
|
||||
// Once we authenticate the first packet, we need to update it.
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
if !packet.hdr.SrcConnectionID.Equal(s.handshakeDestConnID) {
|
||||
s.handshakeDestConnID = packet.hdr.SrcConnectionID
|
||||
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
|
||||
}
|
||||
if s.qlogger != nil {
|
||||
s.qlogger.StartedConnection(
|
||||
s.conn.LocalAddr(),
|
||||
s.conn.RemoteAddr(),
|
||||
s.version,
|
||||
packet.hdr.SrcConnectionID,
|
||||
packet.hdr.DestConnectionID,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.receivedFirstPacket = true
|
||||
s.lastPacketReceivedTime = rcvTime
|
||||
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
|
||||
s.keepAlivePingSent = false
|
||||
|
||||
@@ -51,6 +51,8 @@ var _ = Describe("Session", func() {
|
||||
cryptoSetup *mocks.MockCryptoSetup
|
||||
qlogger *mocks.MockTracer
|
||||
)
|
||||
remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331}
|
||||
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
|
||||
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
|
||||
@@ -81,8 +83,8 @@ var _ = Describe("Session", func() {
|
||||
|
||||
sessionRunner = NewMockSessionRunner(mockCtrl)
|
||||
mconn = NewMockConnection(mockCtrl)
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2)
|
||||
mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{})
|
||||
mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
|
||||
mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
|
||||
tokenGenerator, err := handshake.NewTokenGenerator()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
qlogger = mocks.NewMockTracer(mockCtrl)
|
||||
@@ -526,16 +528,22 @@ var _ = Describe("Session", func() {
|
||||
streamManager.EXPECT().CloseWithError(gomock.Any())
|
||||
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
hdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{DestConnectionID: srcConnID},
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
}).Write(buf, sess.version)).To(Succeed())
|
||||
}
|
||||
Expect(hdr.Write(buf, sess.version)).To(Succeed())
|
||||
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) {
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ConnectionCloseFrame{ErrorCode: qerr.StreamLimitError}).Write(buf, sess.version)).To(Succeed())
|
||||
return &unpackedPacket{data: buf.Bytes(), encryptionLevel: protocol.Encryption1RTT}, nil
|
||||
return &unpackedPacket{
|
||||
hdr: hdr,
|
||||
data: buf.Bytes(),
|
||||
encryptionLevel: protocol.Encryption1RTT,
|
||||
}, nil
|
||||
})
|
||||
gomock.InOrder(
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
|
||||
qlogger.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()),
|
||||
qlogger.EXPECT().Export(),
|
||||
)
|
||||
@@ -595,6 +603,7 @@ var _ = Describe("Session", func() {
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false)
|
||||
sess.receivedPacketHandler = rph
|
||||
packet.rcvTime = rcvTime
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), nil)
|
||||
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
|
||||
})
|
||||
@@ -619,6 +628,7 @@ var _ = Describe("Session", func() {
|
||||
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.Encryption1RTT, rcvTime, true)
|
||||
sess.receivedPacketHandler = rph
|
||||
packet.rcvTime = rcvTime
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []wire.Frame{&wire.PingFrame{}})
|
||||
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
|
||||
})
|
||||
@@ -765,6 +775,7 @@ var _ = Describe("Session", func() {
|
||||
data: []byte{0}, // one PADDING frame
|
||||
}, nil)
|
||||
p1 := getPacket(hdr1, nil)
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any())
|
||||
Expect(sess.handlePacketImpl(p1)).To(BeTrue())
|
||||
// The next packet has to be ignored, since the source connection ID doesn't match.
|
||||
@@ -806,12 +817,16 @@ var _ = Describe("Session", func() {
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}, nil)
|
||||
packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
|
||||
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("coalesced packets", func() {
|
||||
BeforeEach(func() {
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
|
||||
})
|
||||
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) {
|
||||
hdr := &wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
@@ -838,6 +853,7 @@ var _ = Describe("Session", func() {
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
hdr: &wire.ExtendedHeader{},
|
||||
}, nil
|
||||
})
|
||||
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
|
||||
@@ -851,6 +867,7 @@ var _ = Describe("Session", func() {
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}},
|
||||
}, nil
|
||||
})
|
||||
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
|
||||
@@ -859,6 +876,7 @@ var _ = Describe("Session", func() {
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}},
|
||||
}, nil
|
||||
})
|
||||
gomock.InOrder(
|
||||
@@ -880,6 +898,7 @@ var _ = Describe("Session", func() {
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
hdr: &wire.ExtendedHeader{},
|
||||
}, nil
|
||||
}),
|
||||
)
|
||||
@@ -903,6 +922,7 @@ var _ = Describe("Session", func() {
|
||||
return &unpackedPacket{
|
||||
encryptionLevel: protocol.EncryptionHandshake,
|
||||
data: []byte{0},
|
||||
hdr: &wire.ExtendedHeader{},
|
||||
}, nil
|
||||
})
|
||||
_, packet2 := getPacketWithLength(wrongConnID, 123)
|
||||
@@ -1355,7 +1375,6 @@ var _ = Describe("Session", func() {
|
||||
}()
|
||||
handshakeCtx := sess.HandshakeComplete()
|
||||
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
||||
close(finishHandshake)
|
||||
Eventually(handshakeCtx.Done()).Should(BeClosed())
|
||||
Eventually(sphNotified).Should(BeClosed())
|
||||
@@ -1387,7 +1406,6 @@ var _ = Describe("Session", func() {
|
||||
|
||||
handshakeCtx := sess.HandshakeComplete()
|
||||
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
||||
close(finishHandshake)
|
||||
var frames []ackhandler.Frame
|
||||
Eventually(func() []ackhandler.Frame {
|
||||
@@ -1456,7 +1474,6 @@ var _ = Describe("Session", func() {
|
||||
cryptoSetup.EXPECT().RunHandshake()
|
||||
cryptoSetup.EXPECT().DropHandshakeKeys()
|
||||
cryptoSetup.EXPECT().GetSessionTicket()
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
|
||||
mconn.EXPECT().Write(gomock.Any())
|
||||
close(sess.handshakeCompleteChan)
|
||||
sess.run()
|
||||
@@ -1713,7 +1730,6 @@ var _ = Describe("Session", func() {
|
||||
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
|
||||
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
|
||||
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
|
||||
close(sess.handshakeCompleteChan)
|
||||
err := sess.run()
|
||||
nerr, ok := err.(net.Error)
|
||||
@@ -1809,15 +1825,11 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("returns the local address", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
|
||||
mconn.EXPECT().LocalAddr().Return(addr)
|
||||
Expect(sess.LocalAddr()).To(Equal(addr))
|
||||
Expect(sess.LocalAddr()).To(Equal(localAddr))
|
||||
})
|
||||
|
||||
It("returns the remote address", func() {
|
||||
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 7, 1), Port: 7331}
|
||||
mconn.EXPECT().RemoteAddr().Return(addr)
|
||||
Expect(sess.RemoteAddr()).To(Equal(addr))
|
||||
Expect(sess.RemoteAddr()).To(Equal(remoteAddr))
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user