diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 0e1a86e3..a5c0c5c3 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -19,24 +19,30 @@ import ( "github.com/lucas-clemente/quic-go/logging" . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" . "github.com/onsi/gomega" ) +type rcvdPacket struct { + hdr *logging.ExtendedHeader + frames []logging.Frame +} + type rcvdPacketTracer struct { connTracer closed chan struct{} - rcvdPackets []*logging.ExtendedHeader + rcvdPackets []rcvdPacket } func newRcvdPacketTracer() *rcvdPacketTracer { return &rcvdPacketTracer{closed: make(chan struct{})} } -func (t *rcvdPacketTracer) ReceivedPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ []logging.Frame) { - t.rcvdPackets = append(t.rcvdPackets, hdr) +func (t *rcvdPacketTracer) ReceivedPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) { + t.rcvdPackets = append(t.rcvdPackets, rcvdPacket{hdr: hdr, frames: frames}) } func (t *rcvdPacketTracer) Close() { close(t.closed) } -func (t *rcvdPacketTracer) getRcvdPackets() []*logging.ExtendedHeader { +func (t *rcvdPacketTracer) getRcvdPackets() []rcvdPacket { <-t.closed return t.rcvdPackets } @@ -187,11 +193,11 @@ var _ = Describe("0-RTT", func() { } // can be used to extract 0-RTT from a rcvdPacketTracer - get0RTTPackets := func(hdrs []*logging.ExtendedHeader) []protocol.PacketNumber { + get0RTTPackets := func(packets []rcvdPacket) []protocol.PacketNumber { var zeroRTTPackets []protocol.PacketNumber - for _, hdr := range hdrs { - if hdr.Type == protocol.PacketType0RTT { - zeroRTTPackets = append(zeroRTTPackets, hdr.PacketNumber) + for _, p := range packets { + if p.hdr.Type == protocol.PacketType0RTT { + zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) } } return zeroRTTPackets @@ -545,6 +551,85 @@ var _ = Describe("0-RTT", func() { Expect(get0RTTPackets(tracer.getRcvdPackets())).To(BeEmpty()) }) + DescribeTable("flow control limits", + func(addFlowControlLimit func(*quic.Config, uint64)) { + tracer := newRcvdPacketTracer() + firstConf := getQuicConfig(&quic.Config{ + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Versions: []protocol.VersionNumber{version}, + }) + addFlowControlLimit(firstConf, 3) + tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) + + secondConf := getQuicConfig(&quic.Config{ + Versions: []protocol.VersionNumber{version}, + AcceptToken: func(_ net.Addr, _ *quic.Token) bool { return true }, + Tracer: newTracer(func() logging.ConnectionTracer { return tracer }), + }) + addFlowControlLimit(secondConf, 100) + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + secondConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + sess, err := quic.DialAddrEarly( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + + Eventually(written).Should(BeClosed()) + + serverSess, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverSess.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := ioutil.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverSess.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverSess.CloseWithError(0, "")).To(Succeed()) + Eventually(sess.Context().Done()).Should(BeClosed()) + + var processedFirst bool + for _, p := range tracer.getRcvdPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + // All other STREAM frames can only be sent after handshake completion. + Expect(p.hdr.IsLongHeader).To(BeFalse()) + Expect(sf.Offset).ToNot(BeZero()) + } + } + } + } + }, + Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), + Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), + ) + It("correctly deals with 0-RTT rejections", func() { tlsConf, clientConf := dialAndReceiveSessionTicket(nil) // now dial new connection with different transport parameters @@ -662,7 +747,7 @@ var _ = Describe("0-RTT", func() { transfer0RTTData(ln, proxy.LocalPort(), clientConf, PRData, true) - Expect(tracer.rcvdPackets[0].Type).To(Equal(protocol.PacketTypeInitial)) + Expect(tracer.rcvdPackets[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) zeroRTTPackets := get0RTTPackets(tracer.getRcvdPackets()) Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 8a663289..d09673e9 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -776,7 +776,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(client.ConnectionState().Used0RTT).To(BeTrue()) }) - It("rejects 0-RTT, whent the transport parameters changed", func() { + It("rejects 0-RTT, when the transport parameters changed", func() { csc := mocktls.NewMockClientSessionCache(mockCtrl) var state *tls.ClientSessionState receivedSessionTicket := make(chan struct{}) @@ -810,7 +810,7 @@ var _ = Describe("Crypto Setup TLS", func() { clientHelloWrittenChan, client, clientErr, server, serverErr = handshakeWithTLSConf( clientConf, serverConf, clientRTTStats, &utils.RTTStats{}, - &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData + 1}, + &wire.TransportParameters{}, &wire.TransportParameters{InitialMaxData: initialMaxData - 1}, true, ) Expect(clientErr).ToNot(HaveOccurred()) diff --git a/internal/wire/transport_parameter_test.go b/internal/wire/transport_parameter_test.go index f4b3a80f..283fdab6 100644 --- a/internal/wire/transport_parameter_test.go +++ b/internal/wire/transport_parameter_test.go @@ -498,28 +498,48 @@ var _ = Describe("Transport Parameters", func() { Expect(p.ValidFor0RTT(saved)).To(BeTrue()) }) - It("rejects the parameters if the InitialMaxStreamDataBidiLocal changed", func() { - p.InitialMaxStreamDataBidiLocal = 0 + It("rejects the parameters if the InitialMaxStreamDataBidiLocal was reduced", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the InitialMaxStreamDataBidiRemote changed", func() { - p.InitialMaxStreamDataBidiRemote = 0 + It("doesn't reject the parameters if the InitialMaxStreamDataBidiLocal was increased", func() { + p.InitialMaxStreamDataBidiLocal = saved.InitialMaxStreamDataBidiLocal + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiRemote was reduced", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the InitialMaxStreamDataUni changed", func() { - p.InitialMaxStreamDataUni = 0 + It("doesn't reject the parameters if the InitialMaxStreamDataBidiRemote was increased", func() { + p.InitialMaxStreamDataBidiRemote = saved.InitialMaxStreamDataBidiRemote + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataUni was reduced", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the InitialMaxData changed", func() { - p.InitialMaxData = 0 + It("doesn't reject the parameters if the InitialMaxStreamDataUni was increased", func() { + p.InitialMaxStreamDataUni = saved.InitialMaxStreamDataUni + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxData was reduced", func() { + p.InitialMaxData = saved.InitialMaxData - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) - It("rejects the parameters if the MaxBidiStreamNum changed", func() { - p.MaxBidiStreamNum = 0 + It("doesn't reject the parameters if the InitialMaxData was increased", func() { + p.InitialMaxData = saved.InitialMaxData + 1 + Expect(p.ValidFor0RTT(saved)).To(BeTrue()) + }) + + It("rejects the parameters if the MaxBidiStreamNum was reduced", func() { + p.MaxBidiStreamNum = saved.MaxBidiStreamNum - 1 Expect(p.ValidFor0RTT(saved)).To(BeFalse()) }) diff --git a/internal/wire/transport_parameters.go b/internal/wire/transport_parameters.go index 6bf437dc..1f1085bc 100644 --- a/internal/wire/transport_parameters.go +++ b/internal/wire/transport_parameters.go @@ -441,10 +441,10 @@ func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error // ValidFor0RTT checks if the transport parameters match those saved in the session ticket. func (p *TransportParameters) ValidFor0RTT(saved *TransportParameters) bool { - return p.InitialMaxStreamDataBidiLocal == saved.InitialMaxStreamDataBidiLocal && - p.InitialMaxStreamDataBidiRemote == saved.InitialMaxStreamDataBidiRemote && - p.InitialMaxStreamDataUni == saved.InitialMaxStreamDataUni && - p.InitialMaxData == saved.InitialMaxData && + return p.InitialMaxStreamDataBidiLocal >= saved.InitialMaxStreamDataBidiLocal && + p.InitialMaxStreamDataBidiRemote >= saved.InitialMaxStreamDataBidiRemote && + p.InitialMaxStreamDataUni >= saved.InitialMaxStreamDataUni && + p.InitialMaxData >= saved.InitialMaxData && p.MaxBidiStreamNum >= saved.MaxBidiStreamNum && p.MaxUniStreamNum >= saved.MaxUniStreamNum && p.ActiveConnectionIDLimit == saved.ActiveConnectionIDLimit