Merge pull request #2477 from lucas-clemente/fix-corrupted-scid

update the source connection ID if the session was created from an invalid Initial
This commit is contained in:
Marten Seemann
2020-04-09 00:45:15 +07:00
committed by GitHub
3 changed files with 56 additions and 25 deletions

View File

@@ -51,6 +51,8 @@ var _ = Describe("Session", func() {
cryptoSetup *mocks.MockCryptoSetup
qlogger *mocks.MockTracer
)
remoteAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 7331}
srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
@@ -81,8 +83,8 @@ var _ = Describe("Session", func() {
sessionRunner = NewMockSessionRunner(mockCtrl)
mconn = NewMockConnection(mockCtrl)
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}).Times(2)
mconn.EXPECT().LocalAddr().Return(&net.UDPAddr{})
mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
tokenGenerator, err := handshake.NewTokenGenerator()
Expect(err).ToNot(HaveOccurred())
qlogger = mocks.NewMockTracer(mockCtrl)
@@ -526,16 +528,22 @@ var _ = Describe("Session", func() {
streamManager.EXPECT().CloseWithError(gomock.Any())
sessionRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
hdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: srcConnID},
PacketNumberLen: protocol.PacketNumberLen2,
}).Write(buf, sess.version)).To(Succeed())
}
Expect(hdr.Write(buf, sess.version)).To(Succeed())
unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(*wire.Header, time.Time, []byte) (*unpackedPacket, error) {
buf := &bytes.Buffer{}
Expect((&wire.ConnectionCloseFrame{ErrorCode: qerr.StreamLimitError}).Write(buf, sess.version)).To(Succeed())
return &unpackedPacket{data: buf.Bytes(), encryptionLevel: protocol.Encryption1RTT}, nil
return &unpackedPacket{
hdr: hdr,
data: buf.Bytes(),
encryptionLevel: protocol.Encryption1RTT,
}, nil
})
gomock.InOrder(
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
qlogger.EXPECT().ReceivedPacket(gomock.Any(), gomock.Any(), gomock.Any()),
qlogger.EXPECT().Export(),
)
@@ -595,6 +603,7 @@ var _ = Describe("Session", func() {
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false)
sess.receivedPacketHandler = rph
packet.rcvTime = rcvTime
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), nil)
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
})
@@ -619,6 +628,7 @@ var _ = Describe("Session", func() {
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.Encryption1RTT, rcvTime, true)
sess.receivedPacketHandler = rph
packet.rcvTime = rcvTime
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(hdr, protocol.ByteCount(len(packet.data)), []wire.Frame{&wire.PingFrame{}})
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
})
@@ -765,6 +775,7 @@ var _ = Describe("Session", func() {
data: []byte{0}, // one PADDING frame
}, nil)
p1 := getPacket(hdr1, nil)
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(p1.data)), gomock.Any())
Expect(sess.handlePacketImpl(p1)).To(BeTrue())
// The next packet has to be ignored, since the source connection ID doesn't match.
@@ -806,12 +817,16 @@ var _ = Describe("Session", func() {
PacketNumberLen: protocol.PacketNumberLen1,
}, nil)
packet.remoteAddr = &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)}
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
Expect(sess.handlePacketImpl(packet)).To(BeTrue())
})
})
Context("coalesced packets", func() {
BeforeEach(func() {
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1)
})
getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) {
hdr := &wire.ExtendedHeader{
Header: wire.Header{
@@ -838,6 +853,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{},
}, nil
})
qlogger.EXPECT().ReceivedPacket(gomock.Any(), protocol.ByteCount(len(packet.data)), gomock.Any())
@@ -851,6 +867,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}},
}, nil
})
hdrLen2, packet2 := getPacketWithLength(srcConnID, 123)
@@ -859,6 +876,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{Header: wire.Header{SrcConnectionID: destConnID}},
}, nil
})
gomock.InOrder(
@@ -880,6 +898,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{},
}, nil
}),
)
@@ -903,6 +922,7 @@ var _ = Describe("Session", func() {
return &unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
data: []byte{0},
hdr: &wire.ExtendedHeader{},
}, nil
})
_, packet2 := getPacketWithLength(wrongConnID, 123)
@@ -1355,7 +1375,6 @@ var _ = Describe("Session", func() {
}()
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
close(finishHandshake)
Eventually(handshakeCtx.Done()).Should(BeClosed())
Eventually(sphNotified).Should(BeClosed())
@@ -1387,7 +1406,6 @@ var _ = Describe("Session", func() {
handshakeCtx := sess.HandshakeComplete()
Consistently(handshakeCtx.Done()).ShouldNot(BeClosed())
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
close(finishHandshake)
var frames []ackhandler.Frame
Eventually(func() []ackhandler.Frame {
@@ -1456,7 +1474,6 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().RunHandshake()
cryptoSetup.EXPECT().DropHandshakeKeys()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{}) // the remote addr is needed for the token
mconn.EXPECT().Write(gomock.Any())
close(sess.handshakeCompleteChan)
sess.run()
@@ -1713,7 +1730,6 @@ var _ = Describe("Session", func() {
cryptoSetup.EXPECT().RunHandshake().MaxTimes(1)
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().DropHandshakeKeys().MaxTimes(1)
mconn.EXPECT().RemoteAddr().Return(&net.UDPAddr{})
close(sess.handshakeCompleteChan)
err := sess.run()
nerr, ok := err.(net.Error)
@@ -1809,15 +1825,11 @@ var _ = Describe("Session", func() {
})
It("returns the local address", func() {
addr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}
mconn.EXPECT().LocalAddr().Return(addr)
Expect(sess.LocalAddr()).To(Equal(addr))
Expect(sess.LocalAddr()).To(Equal(localAddr))
})
It("returns the remote address", func() {
addr := &net.UDPAddr{IP: net.IPv4(1, 2, 7, 1), Port: 7331}
mconn.EXPECT().RemoteAddr().Return(addr)
Expect(sess.RemoteAddr()).To(Equal(addr))
Expect(sess.RemoteAddr()).To(Equal(remoteAddr))
})
})