diff --git a/client.go b/client.go index 533913bf..2431d9be 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,8 @@ type client struct { // If it is started with Dial, we take a packet conn as a parameter. createdPacketConn bool + use0RTT bool + packetHandlers packetHandlerManager versionNegotiated utils.AtomicBool // has the server accepted our version @@ -65,6 +67,18 @@ func DialAddr( return DialAddrContext(context.Background(), addr, tlsConf, config) } +// DialAddrEarly establishes a new 0-RTT QUIC connection to a server. +// It uses a new UDP connection and closes this connection when the QUIC session is closed. +// The hostname for SNI is taken from the given address. +func DialAddrEarly( + addr string, + tlsConf *tls.Config, + config *Config, +) (EarlySession, error) { + defer utils.Logger.WithPrefix(utils.DefaultLogger, "client").Debugf("Returning early session") + return dialAddrContext(context.Background(), addr, tlsConf, config, true) +} + // DialAddrContext establishes a new QUIC connection to a server using the provided context. // See DialAddr for details. func DialAddrContext( @@ -73,6 +87,16 @@ func DialAddrContext( tlsConf *tls.Config, config *Config, ) (Session, error) { + return dialAddrContext(ctx, addr, tlsConf, config, false) +} + +func dialAddrContext( + ctx context.Context, + addr string, + tlsConf *tls.Config, + config *Config, + use0RTT bool, +) (quicSession, error) { udpAddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err @@ -81,7 +105,7 @@ func DialAddrContext( if err != nil { return nil, err } - return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true) + return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, use0RTT, true) } // Dial establishes a new QUIC connection to a server using a net.PacketConn. @@ -96,7 +120,22 @@ func Dial( tlsConf *tls.Config, config *Config, ) (Session, error) { - return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config) + return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, false, false) +} + +// DialEarly establishes a new 0-RTT QUIC connection to a server using a net.PacketConn. +// The same PacketConn can be used for multiple calls to Dial and Listen, +// QUIC connection IDs are used for demultiplexing the different connections. +// The host parameter is used for SNI. +// The tls.Config must define an application protocol (using NextProtos). +func DialEarly( + pconn net.PacketConn, + remoteAddr net.Addr, + host string, + tlsConf *tls.Config, + config *Config, +) (Session, error) { + return dialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config, true, false) } // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context. @@ -109,7 +148,7 @@ func DialContext( tlsConf *tls.Config, config *Config, ) (Session, error) { - return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false) + return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false, false) } func dialContext( @@ -119,8 +158,9 @@ func dialContext( host string, tlsConf *tls.Config, config *Config, + use0RTT bool, createdPacketConn bool, -) (Session, error) { +) (quicSession, error) { if tlsConf == nil { return nil, errors.New("quic: tls.Config not set") } @@ -129,7 +169,7 @@ func dialContext( if err != nil { return nil, err } - c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn) + c, err := newClient(pconn, remoteAddr, config, tlsConf, host, use0RTT, createdPacketConn) if err != nil { return nil, err } @@ -146,6 +186,7 @@ func newClient( config *Config, tlsConf *tls.Config, host string, + use0RTT bool, createdPacketConn bool, ) (*client, error) { if tlsConf == nil { @@ -186,6 +227,7 @@ func newClient( destConnID: destConnID, conn: &conn{pconn: pconn, currentAddr: remoteAddr}, createdPacketConn: createdPacketConn, + use0RTT: use0RTT, tlsConf: tlsConf, config: config, version: config.Versions[0], @@ -231,6 +273,13 @@ func (c *client) establishSecureConnection(ctx context.Context) error { errorChan <- err }() + // only set when we're using 0-RTT + // Otherwise, earlySessionChan will be nil. Receiving from a nil chan blocks forever. + var earlySessionChan <-chan struct{} + if c.use0RTT { + earlySessionChan = c.session.earlySessionReady() + } + select { case <-ctx.Done(): // The session will send a PeerGoingAway error to the server. @@ -238,6 +287,9 @@ func (c *client) establishSecureConnection(ctx context.Context) error { return ctx.Err() case err := <-errorChan: return err + case <-earlySessionChan: + // ready to send 0-RTT data + return nil case <-c.session.HandshakeComplete().Done(): // handshake successfully completed return nil diff --git a/client_test.go b/client_test.go index a720d7a6..709bda2c 100644 --- a/client_test.go +++ b/client_test.go @@ -257,6 +257,50 @@ var _ = Describe("Client", func() { Eventually(run).Should(BeClosed()) }) + It("returns early sessions", func() { + manager := NewMockPacketHandlerManager(mockCtrl) + manager.EXPECT().Add(gomock.Any(), gomock.Any()) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) + + readyChan := make(chan struct{}) + done := make(chan struct{}) + newClientSession = func( + _ connection, + runner sessionRunner, + _ protocol.ConnectionID, + _ protocol.ConnectionID, + _ *Config, + _ *tls.Config, + _ protocol.PacketNumber, + _ protocol.VersionNumber, + _ utils.Logger, + _ protocol.VersionNumber, + ) quicSession { + sess := NewMockQuicSession(mockCtrl) + sess.EXPECT().run().Do(func() { <-done }) + sess.EXPECT().HandshakeComplete().Return(context.Background()) + sess.EXPECT().earlySessionReady().Return(readyChan) + return sess + } + + go func() { + defer GinkgoRecover() + defer close(done) + s, err := DialEarly( + packetConn, + addr, + "localhost:1337", + tlsConf, + &Config{}, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(s).ToNot(BeNil()) + }() + Consistently(done).ShouldNot(BeClosed()) + close(readyChan) + Eventually(done).Should(BeClosed()) + }) + It("returns an error that occurs while waiting for the handshake to complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) diff --git a/session.go b/session.go index 0765424f..d4cd63b1 100644 --- a/session.go +++ b/session.go @@ -467,6 +467,7 @@ func (s *session) run() error { s.scheduleSending() if zeroRTTParams != nil { s.processTransportParameters(zeroRTTParams) + close(s.earlySessionReadyChan) } case closeErr := <-s.closeChan: // put the close error back into the channel, so that the run loop can receive it