From 6c3876d6b3d7493f1d2521555560acf5df1178f5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 17 Mar 2021 18:18:09 +0800 Subject: [PATCH 1/2] allow 0-RTT when the server's stream receive limit is increased --- integrationtests/self/zero_rtt_test.go | 95 ++++++++++++++++++++--- internal/wire/transport_parameter_test.go | 31 ++++++-- internal/wire/transport_parameters.go | 6 +- 3 files changed, 112 insertions(+), 20 deletions(-) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 3cbae9c6..476cd16d 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -22,21 +22,26 @@ import ( . "github.com/onsi/gomega" ) +type rcvdPacket struct { + hdr *logging.ExtendedHeader + frames []logging.Frame +} + type rcvdPacketTracer struct { connTracer closed chan struct{} - rcvdPackets []*logging.ExtendedHeader + rcvdPackets []rcvdPacket } func newRcvdPacketTracer() *rcvdPacketTracer { return &rcvdPacketTracer{closed: make(chan struct{})} } -func (t *rcvdPacketTracer) ReceivedPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ []logging.Frame) { - t.rcvdPackets = append(t.rcvdPackets, hdr) +func (t *rcvdPacketTracer) ReceivedPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) { + t.rcvdPackets = append(t.rcvdPackets, rcvdPacket{hdr: hdr, frames: frames}) } func (t *rcvdPacketTracer) Close() { close(t.closed) } -func (t *rcvdPacketTracer) getRcvdPackets() []*logging.ExtendedHeader { +func (t *rcvdPacketTracer) getRcvdPackets() []rcvdPacket { <-t.closed return t.rcvdPackets } @@ -187,11 +192,11 @@ var _ = Describe("0-RTT", func() { } // can be used to extract 0-RTT from a rcvdPacketTracer - get0RTTPackets := func(hdrs []*logging.ExtendedHeader) []protocol.PacketNumber { + get0RTTPackets := func(packets []rcvdPacket) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber - for _, hdr := range hdrs { - if hdr.Type == protocol.PacketType0RTT { - zeroRTTPackets = append(zeroRTTPackets, hdr.PacketNumber) + for _, p := range packets { + if p.hdr.Type == protocol.PacketType0RTT { + zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) } } return zeroRTTPackets @@ -545,6 +550,78 @@ var _ = Describe("0-RTT", func() { Expect(get0RTTPackets(tracer.getRcvdPackets())).To(BeEmpty()) }) + It("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + InitialStreamReceiveWindow: 3, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + })) + + tracer := newRcvdPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + InitialStreamReceiveWindow: 100, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + sess, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + + Eventually(written).Should(BeClosed()) + + serverSess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverSess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverSess.CloseWithError(0, "")).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + + var processedFirst bool + for _, p := range tracer.getRcvdPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + // All other STREAM frames can only be sent after handshake completion. + Expect(p.hdr.IsLongHeader).To(BeFalse()) + Expect(sf.Offset).ToNot(BeZero()) + } + } + } + } + }) + It("correctly deals with 0-RTT rejections", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now dial new connection with different transport parameters @@ -648,7 +725,7 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), clientConf, PRData, true) - Expect(tracer.rcvdPackets[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(tracer.rcvdPackets[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) zeroRTTPackets := get0RTTPackets(tracer.getRcvdPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index f4b3a80f..811c47ea 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -498,28 +498,43 @@ var _ = Describe("Transport Parameters", func() { Expect(p.ValidFor0RTT(saved)).To(BeTrue()) }) - It("rejects the parameters if the InitialMaxStreamDataBidiLocal changed", func() { - p.InitialMaxStreamDataBidiLocal = 0 + It("rejects the parameters if the InitialMaxStreamDataBidiLocal was reduced", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the InitialMaxStreamDataBidiRemote changed", func() { - p.InitialMaxStreamDataBidiRemote = 0 + It("doesn't reject the parameters if the InitialMaxStreamDataBidiLocal was increased", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiRemote was reduced", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the InitialMaxStreamDataUni changed", func() { - p.InitialMaxStreamDataUni = 0 + It("doesn't reject the parameters if the InitialMaxStreamDataBidiRemote was increased", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataUni was reduced", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) + It("doesn't reject the parameters if the InitialMaxStreamDataUni was increased", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + It("rejects the parameters if the InitialMaxData changed", func() { p.InitialMaxData = 0 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the MaxBidiStreamNum changed", func() { - p.MaxBidiStreamNum = 0 + It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { + p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 6bf437dc..959d8257 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -441,9 +441,9 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { - return p.InitialMaxStreamDataBidiLocal == saved.InitialMaxStreamDataBidiLocal && - p.InitialMaxStreamDataBidiRemote == saved.InitialMaxStreamDataBidiRemote && - p.InitialMaxStreamDataUni == saved.InitialMaxStreamDataUni && + return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && + p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && + p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && p.InitialMaxData == saved.InitialMaxData && p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && p.MaxUniStreamNum >= saved.MaxUniStreamNum && From 31ac5ca60de00a7c68a4466bdfda4131097defcd Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 17 Mar 2021 18:45:56 +0800 Subject: [PATCH 2/2] allow 0-RTT when the server's connection receive limit is increased --- integrationtests/self/zero_rtt_test.go | 134 ++++++++++++---------- internal/handshake/crypto_setup_test.go | 4 +- internal/wire/transport_parameter_test.go | 9 +- internal/wire/transport_parameters.go | 2 +- 4 files changed, 81 insertions(+), 68 deletions(-) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 476cd16d..72b02e36 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -19,6 +19,7 @@ import ( "github.com/lucas-clemente/quic-go/logging" . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) @@ -550,77 +551,84 @@ var _ = Describe("0-RTT", func() { Expect(get0RTTPackets(tracer.getRcvdPackets())).To(BeEmpty()) }) - It("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ - InitialStreamReceiveWindow: 3, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - })) + DescribeTable("flow control limits", + func(addFlowControlLimit func(*quic.Config, uint64)) { + tracer := newRcvdPacketTracer() + firstConf := getQuicConfig(&quic.Config{ + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Versions: []protocol.VersionNumber{version}, + }) + addFlowControlLimit(firstConf, 3) + tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) - tracer := newRcvdPacketTracer() - ln, err := quic.ListenAddrEarly( - "localhost:0", - tlsConf, - getQuicConfig(&quic.Config{ - Versions: []protocol.VersionNumber{version}, - AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, - InitialStreamReceiveWindow: 100, - Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), - }), - ) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) - defer proxy.Close() - - sess, err := quic.DialAddrEarly( - fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, - getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), - ) - Expect(err).ToNot(HaveOccurred()) - str, err := sess.OpenUniStream() - Expect(err).ToNot(HaveOccurred()) - written := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(written) - _, err := str.Write([]byte("foobar")) + secondConf := getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }) + addFlowControlLimit(secondConf, 100) + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + secondConf, + ) Expect(err).ToNot(HaveOccurred()) - Expect(str.Close()).To(Succeed()) - }() + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() - Eventually(written).Should(BeClosed()) + sess, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() - serverSess, err := ln.Accept(context.Background()) - Expect(err).ToNot(HaveOccurred()) - rstr, err := serverSess.AcceptUniStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(rstr) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) - Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) - Expect(serverSess.CloseWithError(0, "")).To(Succeed()) - Eventually(sess.Context().Done()).Should(BeClosed()) + Eventually(written).Should(BeClosed()) - var processedFirst bool - for _, p := range tracer.getRcvdPackets() { - for _, f := range p.frames { - if sf, ok := f.(*logging.StreamFrame); ok { - if !processedFirst { - // The first STREAM should have been sent in a 0-RTT packet. - // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. - Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) - Expect(sf.Length).To(BeEquivalentTo(3)) - processedFirst = true - } else { - // All other STREAM frames can only be sent after handshake completion. - Expect(p.hdr.IsLongHeader).To(BeFalse()) - Expect(sf.Offset).ToNot(BeZero()) + serverSess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverSess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverSess.CloseWithError(0, "")).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + + var processedFirst bool + for _, p := range tracer.getRcvdPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + // All other STREAM frames can only be sent after handshake completion. + Expect(p.hdr.IsLongHeader).To(BeFalse()) + Expect(sf.Offset).ToNot(BeZero()) + } } } } - } - }) + }, + Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), + Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), + ) It("correctly deals with 0-RTT rejections", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 8a663289..d09673e9 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -776,7 +776,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(client.ConnectionState().Used0RTT).To(BeTrue()) }) - It("rejects 0-RTT, whent the transport parameters changed", func() { + It("rejects 0-RTT, when the transport parameters changed", func() { csc := mocktls.NewMockClientSessionCache(mockCtrl) var state *tls.ClientSessionState receivedSessionTicket := make(chan struct{}) @@ -810,7 +810,7 @@ var _ = Describe("Crypto Setup TLS", func() { clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( clientConf, serverConf, clientRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1}, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1}, true, ) Expect(clientErr).ToNot(HaveOccurred()) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index 811c47ea..283fdab6 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -528,11 +528,16 @@ var _ = Describe("Transport Parameters", func() { Expect(p.ValidFor0RTT(saved)).To(BeTrue()) }) - It("rejects the parameters if the InitialMaxData changed", func() { - p.InitialMaxData = 0 + It("rejects the parameters if the InitialMaxData was reduced", func() { + p.InitialMaxData = saved.InitialMaxData - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) + It("doesn't reject the parameters if the InitialMaxData was increased", func() { + p.InitialMaxData = saved.InitialMaxData + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 959d8257..1f1085bc 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -444,7 +444,7 @@ func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && - p.InitialMaxData == saved.InitialMaxData && + p.InitialMaxData >= saved.InitialMaxData && p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && p.MaxUniStreamNum >= saved.MaxUniStreamNum && p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit