forked from quic-go/quic-go
handle Version Negotiation packets in the session
This commit is contained in:
241
client_test.go
241
client_test.go
@@ -47,6 +47,7 @@ var _ = Describe("Client", func() {
|
||||
initialPacketNumber protocol.PacketNumber,
|
||||
initialVersion protocol.VersionNumber,
|
||||
enable0RTT bool,
|
||||
hasNegotiatedVersion bool,
|
||||
qlogger qlog.Tracer,
|
||||
logger utils.Logger,
|
||||
v protocol.VersionNumber,
|
||||
@@ -65,16 +66,6 @@ var _ = Describe("Client", func() {
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket {
|
||||
data, err := wire.ComposeVersionNegotiation(connID, nil, versions)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue())
|
||||
return &receivedPacket{
|
||||
rcvTime: time.Now(),
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
tlsConf = &tls.Config{NextProtos: []string{"proto1"}}
|
||||
connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37}
|
||||
@@ -169,6 +160,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -201,6 +193,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -233,6 +226,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -271,6 +265,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
enable0RTT bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -313,6 +308,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
enable0RTT bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -360,6 +356,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -403,6 +400,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -454,6 +452,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
@@ -574,6 +573,7 @@ var _ = Describe("Client", func() {
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber, /* initial version */
|
||||
_ bool,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
@@ -596,183 +596,58 @@ var _ = Describe("Client", func() {
|
||||
Expect(conf.Versions).To(Equal(config.Versions))
|
||||
})
|
||||
|
||||
Context("version negotiation", func() {
|
||||
var origSupportedVersions []protocol.VersionNumber
|
||||
It("creates a new session after version negotiation", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any()).Times(2)
|
||||
manager.EXPECT().Destroy()
|
||||
mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
BeforeEach(func() {
|
||||
origSupportedVersions = protocol.SupportedVersions
|
||||
protocol.SupportedVersions = append(protocol.SupportedVersions, []protocol.VersionNumber{77, 78}...)
|
||||
})
|
||||
initialVersion := cl.version
|
||||
|
||||
AfterEach(func() {
|
||||
protocol.SupportedVersions = origSupportedVersions
|
||||
})
|
||||
|
||||
It("returns an error that occurs during version negotiation", func() {
|
||||
manager := NewMockPacketHandlerManager(mockCtrl)
|
||||
manager.EXPECT().Add(connID, gomock.Any())
|
||||
mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil)
|
||||
|
||||
testErr := errors.New("early handshake error")
|
||||
newClientSession = func(
|
||||
conn connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
_ *Config,
|
||||
_ *tls.Config,
|
||||
_ protocol.PacketNumber,
|
||||
_ protocol.VersionNumber,
|
||||
_ bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
_ protocol.VersionNumber,
|
||||
) quicSession {
|
||||
Expect(conn.Write([]byte("0 fake CHLO"))).To(Succeed())
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().run().Return(testErr)
|
||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||
return sess
|
||||
var counter int
|
||||
newClientSession = func(
|
||||
_ connection,
|
||||
_ sessionRunner,
|
||||
_ protocol.ConnectionID,
|
||||
_ protocol.ConnectionID,
|
||||
configP *Config,
|
||||
_ *tls.Config,
|
||||
pn protocol.PacketNumber,
|
||||
version protocol.VersionNumber,
|
||||
_ bool,
|
||||
hasNegotiatedVersion bool,
|
||||
_ qlog.Tracer,
|
||||
_ utils.Logger,
|
||||
versionP protocol.VersionNumber,
|
||||
) quicSession {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().HandshakeComplete().Return(context.Background())
|
||||
if counter == 0 {
|
||||
Expect(pn).To(BeZero())
|
||||
Expect(version).To(Equal(initialVersion))
|
||||
Expect(hasNegotiatedVersion).To(BeFalse())
|
||||
sess.EXPECT().run().Return(&errCloseForRecreating{
|
||||
nextPacketNumber: 109,
|
||||
nextVersion: 789,
|
||||
})
|
||||
} else {
|
||||
Expect(pn).To(Equal(protocol.PacketNumber(109)))
|
||||
Expect(version).ToNot(Equal(initialVersion))
|
||||
Expect(version).To(Equal(protocol.VersionNumber(789)))
|
||||
Expect(hasNegotiatedVersion).To(BeTrue())
|
||||
sess.EXPECT().run()
|
||||
}
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any())
|
||||
_, err := Dial(
|
||||
packetConn,
|
||||
addr,
|
||||
"localhost:1337",
|
||||
tlsConf,
|
||||
config,
|
||||
)
|
||||
Expect(err).To(MatchError(testErr))
|
||||
})
|
||||
counter++
|
||||
return sess
|
||||
}
|
||||
|
||||
It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
cl.session = sess
|
||||
cl.config = config
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
Version: cl.version,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
||||
Eventually(cl.versionNegotiated.Get).Should(BeTrue())
|
||||
})
|
||||
|
||||
// Illustrates that adversary that injects a version negotiation packet
|
||||
// with no supported versions can break a connection.
|
||||
It("errors if no matching version is found", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
done := make(chan struct{})
|
||||
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
|
||||
close(done)
|
||||
})
|
||||
cl.session = sess
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1337})
|
||||
hdr, _, _, err := wire.ParsePacket(p.data, 0)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(hdr)
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
done := make(chan struct{})
|
||||
sess.EXPECT().destroy(gomock.Any()).Do(func(err error) {
|
||||
defer GinkgoRecover()
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("No compatible QUIC version found."))
|
||||
close(done)
|
||||
})
|
||||
cl.session = sess
|
||||
v := protocol.VersionNumber(1234)
|
||||
Expect(v).ToNot(Equal(cl.version))
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, []protocol.VersionNumber{v}))
|
||||
Eventually(done).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("changes to the version preferred by the quic.Config", func() {
|
||||
phm := NewMockPacketHandlerManager(mockCtrl)
|
||||
cl.packetHandlers = phm
|
||||
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
destroyed := make(chan struct{})
|
||||
sess.EXPECT().closeForRecreating().Do(func() {
|
||||
close(destroyed)
|
||||
})
|
||||
cl.session = sess
|
||||
versions := []protocol.VersionNumber{1234, 4321}
|
||||
cl.config = &Config{Versions: versions}
|
||||
qlogger.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any())
|
||||
cl.handlePacket(composeVersionNegotiationPacket(connID, versions))
|
||||
Eventually(destroyed).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
|
||||
})
|
||||
|
||||
It("drops unparseable version negotiation packets", func() {
|
||||
cl.config = config
|
||||
ver := cl.version
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
|
||||
p.data = p.data[:len(p.data)-1]
|
||||
done := make(chan struct{})
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropHeaderParseError).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
||||
close(done)
|
||||
})
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
|
||||
It("drops version negotiation packets if any other packet was received before", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
cl.session = sess
|
||||
cl.config = config
|
||||
buf := &bytes.Buffer{}
|
||||
Expect((&wire.ExtendedHeader{
|
||||
Header: wire.Header{
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
Version: cl.version,
|
||||
},
|
||||
PacketNumberLen: protocol.PacketNumberLen3,
|
||||
}).Write(buf, protocol.VersionTLS)).To(Succeed())
|
||||
cl.handlePacket(&receivedPacket{data: buf.Bytes()})
|
||||
|
||||
ver := cl.version
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{1234})
|
||||
done := make(chan struct{})
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedPacket).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
||||
close(done)
|
||||
})
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
|
||||
It("drops version negotiation packets that contain the offered version", func() {
|
||||
cl.config = config
|
||||
ver := cl.version
|
||||
p := composeVersionNegotiationPacket(connID, []protocol.VersionNumber{ver})
|
||||
done := make(chan struct{})
|
||||
qlogger.EXPECT().DroppedPacket(qlog.PacketTypeVersionNegotiation, protocol.ByteCount(len(p.data)), qlog.PacketDropUnexpectedVersion).Do(func(qlog.PacketType, protocol.ByteCount, qlog.PacketDropReason) {
|
||||
close(done)
|
||||
})
|
||||
cl.handlePacket(p)
|
||||
Eventually(done).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
gomock.InOrder(
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), initialVersion, gomock.Any(), gomock.Any()),
|
||||
qlogger.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionNumber(789), gomock.Any(), gomock.Any()),
|
||||
)
|
||||
_, err := DialAddr("localhost:7890", tlsConf, config)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(counter).To(Equal(2))
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user