diff --git a/client.go b/client.go index f7c1ffad2..80f4410b5 100644 --- a/client.go +++ b/client.go @@ -24,6 +24,8 @@ type client struct { conn connection hostname string + receivedRetry bool + versionNegotiated bool // has the server accepted our version receivedVersionNegotiationPacket bool negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet @@ -259,6 +261,9 @@ func (c *client) dialTLS(ctx context.Context) error { return err } c.logger.Infof("Received a Retry packet. Recreating session.") + c.mutex.Lock() + c.receivedRetry = true + c.mutex.Unlock() if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil { return err } @@ -370,7 +375,13 @@ func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID) } if hdr.IsLongHeader { - if hdr.Type != protocol.PacketTypeRetry && hdr.Type != protocol.PacketTypeHandshake { + switch hdr.Type { + case protocol.PacketTypeRetry: + if c.receivedRetry { + return nil + } + case protocol.PacketTypeHandshake: + default: return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) } if protocol.ByteCount(len(packetData)) < hdr.PayloadLen { diff --git a/client_test.go b/client_test.go index f402bdfe8..49a10323d 100644 --- a/client_test.go +++ b/client_test.go @@ -693,6 +693,60 @@ var _ = Describe("Client", func() { Expect(sessions).To(BeEmpty()) }) + It("only accepts one Retry packet", func() { + config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} + sess1 := NewMockPacketHandler(mockCtrl) + sess1.EXPECT().run().Return(handshake.ErrCloseSessionForRetry) + // don't EXPECT any call to handlePacket() + sess2 := NewMockPacketHandler(mockCtrl) + run := make(chan struct{}) + sess2.EXPECT().run().Do(func() { <-run }) + sessions := make(chan *MockPacketHandler, 2) + sessions <- sess1 + sessions <- sess2 + newTLSClientSession = func( + connP connection, + _ sessionRunner, + hostnameP string, + versionP protocol.VersionNumber, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + configP *Config, + tls handshake.MintTLS, + paramsChan <-chan handshake.TransportParameters, + _ protocol.PacketNumber, + _ utils.Logger, + ) (packetHandler, error) { + return <-sessions, nil + } + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := Dial(packetConn, addr, "quic.clemente.io:1337", nil, config) + Expect(err).ToNot(HaveOccurred()) + close(done) + }() + + buf := &bytes.Buffer{} + h := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: connID, + DestConnectionID: connID, + PacketNumberLen: protocol.PacketNumberLen1, + } + err := h.Write(buf, protocol.PerspectiveServer, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + Eventually(sessions).Should(BeEmpty()) + packetConn.dataToRead <- buf.Bytes() + time.Sleep(50 * time.Millisecond) // make sure the packet is read and discarded + + // make the go routine return + close(run) + Eventually(done).Should(BeClosed()) + }) + Context("handling packets", func() { It("handles packets", func() { sess := NewMockPacketHandler(mockCtrl)