handle Version Negotiation packets in the session

This commit is contained in:
Marten Seemann
2020-06-30 17:13:50 +07:00
parent 6b42c7a045
commit 06ad477b9b
6 changed files with 226 additions and 298 deletions

View File

@@ -47,6 +47,7 @@ var _ = Describe("Client", func() {
initialPacketNumber protocol.PacketNumber,
initialVersion protocol.VersionNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
qlogger qlog.Tracer,
logger utils.Logger,
v protocol.VersionNumber,
@@ -65,16 +66,6 @@ var _ = Describe("Client", func() {
return b.Bytes()
}
composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket {
data, err := wire.ComposeVersionNegotiation(connID, nil, versions)
Expect(err).ToNot(HaveOccurred())
Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue())
return &receivedPacket{
rcvTime: time.Now(),
data: data,
}
}
BeforeEach(func() {
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
@@ -169,6 +160,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -201,6 +193,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -233,6 +226,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -271,6 +265,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
enable0RTT bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -313,6 +308,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
enable0RTT bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -360,6 +356,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -403,6 +400,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -454,6 +452,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
@@ -574,6 +573,7 @@ var _ = Describe("Client", func() {
_ protocol.PacketNumber,
_ protocol.VersionNumber, /* initial version */
_ bool,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
versionP protocol.VersionNumber,
@@ -596,183 +596,58 @@ var _ = Describe("Client", func() {
Expect(conf.Versions).To(Equal(config.Versions))
})
Context("version negotiation", func() {
var origSupportedVersions []protocol.VersionNumber
It("creates a new session after version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
manager.EXPECT().Destroy()
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
BeforeEach(func() {
origSupportedVersions = protocol.SupportedVersions
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...)
})
initialVersion := cl.version
AfterEach(func() {
protocol.SupportedVersions = origSupportedVersions
})
It("returns an error that occurs during version negotiation", func() {
manager := NewMockPacketHandlerManager(mockCtrl)
manager.EXPECT().Add(connID, gomock.Any())
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
testErr := errors.New("early handshake error")
newClientSession = func(
conn connection,
_ sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
_ *Config,
_ *tls.Config,
_ protocol.PacketNumber,
_ protocol.VersionNumber,
_ bool,
_ qlog.Tracer,
_ utils.Logger,
_ protocol.VersionNumber,
) quicSession {
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().run().Return(testErr)
sess.EXPECT().HandshakeComplete().Return(context.Background())
return sess
var counter int
newClientSession = func(
_ connection,
_ sessionRunner,
_ protocol.ConnectionID,
_ protocol.ConnectionID,
configP *Config,
_ *tls.Config,
pn protocol.PacketNumber,
version protocol.VersionNumber,
_ bool,
hasNegotiatedVersion bool,
_ qlog.Tracer,
_ utils.Logger,
versionP protocol.VersionNumber,
) quicSession {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().HandshakeComplete().Return(context.Background())
if counter == 0 {
Expect(pn).To(BeZero())
Expect(version).To(Equal(initialVersion))
Expect(hasNegotiatedVersion).To(BeFalse())
sess.EXPECT().run().Return(&errCloseForRecreating{
nextPacketNumber: 109,
nextVersion: 789,
})
} else {
Expect(pn).To(Equal(protocol.PacketNumber(109)))
Expect(version).ToNot(Equal(initialVersion))
Expect(version).To(Equal(protocol.VersionNumber(789)))
Expect(hasNegotiatedVersion).To(BeTrue())
sess.EXPECT().run()
}
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
_, err := Dial(
packetConn,
addr,
"localhost:1337",
tlsConf,
config,
)
Expect(err).To(MatchError(testErr))
})
counter++
return sess
}
It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
cl.config = config
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
DestConnectionID: connID,
SrcConnectionID: connID,
Version: cl.version,
},
PacketNumberLen: protocol.PacketNumberLen3,
}).Write(buf, protocol.VersionTLS)).To(Succeed())
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
Eventually(cl.versionNegotiated.Get).Should(BeTrue())
})
// Illustrates that adversary that injects a version negotiation packet
// with no supported versions can break a connection.
It("errors if no matching version is found", func() {
sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{})
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
close(done)
})
cl.session = sess
cl.config = &Config{Versions: protocol.SupportedVersions}
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337})
hdr, _, _, err := wire.ParsePacket(p.data, 0)
Expect(err).ToNot(HaveOccurred())
qlogger.EXPECT().ReceivedVersionNegotiationPacket(hdr)
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
})
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
sess := NewMockQuicSession(mockCtrl)
done := make(chan struct{})
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
defer GinkgoRecover()
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
close(done)
})
cl.session = sess
v := protocol.VersionNumber(1234)
Expect(v).ToNot(Equal(cl.version))
cl.config = &Config{Versions: protocol.SupportedVersions}
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v}))
Eventually(done).Should(BeClosed())
})
It("changes to the version preferred by the quic.Config", func() {
phm := NewMockPacketHandlerManager(mockCtrl)
cl.packetHandlers = phm
sess := NewMockQuicSession(mockCtrl)
destroyed := make(chan struct{})
sess.EXPECT().closeForRecreating().Do(func() {
close(destroyed)
})
cl.session = sess
versions := []protocol.VersionNumber{1234, 4321}
cl.config = &Config{Versions: versions}
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
cl.handlePacket(composeVersionNegotiationPacket(connID, versions))
Eventually(destroyed).Should(BeClosed())
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
})
It("drops unparseable version negotiation packets", func() {
cl.config = config
ver := cl.version
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
p.data = p.data[:len(p.data)-1]
done := make(chan struct{})
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
close(done)
})
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
Expect(cl.version).To(Equal(ver))
})
It("drops version negotiation packets if any other packet was received before", func() {
sess := NewMockQuicSession(mockCtrl)
sess.EXPECT().handlePacket(gomock.Any())
cl.session = sess
cl.config = config
buf := &bytes.Buffer{}
Expect((&wire.ExtendedHeader{
Header: wire.Header{
DestConnectionID: connID,
SrcConnectionID: connID,
Version: cl.version,
},
PacketNumberLen: protocol.PacketNumberLen3,
}).Write(buf, protocol.VersionTLS)).To(Succeed())
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
ver := cl.version
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})
done := make(chan struct{})
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
close(done)
})
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
Expect(cl.version).To(Equal(ver))
})
It("drops version negotiation packets that contain the offered version", func() {
cl.config = config
ver := cl.version
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
done := make(chan struct{})
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
close(done)
})
cl.handlePacket(p)
Eventually(done).Should(BeClosed())
Expect(cl.version).To(Equal(ver))
})
gomock.InOrder(
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), initialVersion, gomock.Any(), gomock.Any()),
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionNumber(789), gomock.Any(), gomock.Any()),
)
_, err := DialAddr("localhost:7890", tlsConf, config)
Expect(err).ToNot(HaveOccurred())
Expect(counter).To(Equal(2))
})
})