diff --git a/http3/server.go b/http3/server.go index e546a930..69a186d7 100644 --- a/http3/server.go +++ b/http3/server.go @@ -14,7 +14,6 @@ import ( "time" "github.com/quic-go/quic-go" - "github.com/quic-go/quic-go/internal/handshake" "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" @@ -66,8 +65,9 @@ func ConfigureTLSConfig(tlsConf *tls.Config) *tls.Config { GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { // determine the ALPN from the QUIC version used proto := NextProtoH3 - if qconn, ok := ch.Conn.(handshake.ConnWithVersion); ok { - proto = versionToALPN(qconn.GetQUICVersion()) + val := ch.Context().Value(quic.QUICVersionContextKey) + if v, ok := val.(quic.VersionNumber); ok { + proto = versionToALPN(v) } config := tlsConf if tlsConf.GetConfigForClient != nil { diff --git a/http3/server_test.go b/http3/server_test.go index 93fdaa31..be1dac3a 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -28,19 +28,6 @@ import ( gmtypes "github.com/onsi/gomega/types" ) -type mockConn struct { - net.Conn - version protocol.VersionNumber -} - -func newMockConn(version protocol.VersionNumber) net.Conn { - return &mockConn{version: version} -} - -func (c *mockConn) GetQUICVersion() protocol.VersionNumber { - return c.version -} - type mockAddr struct { addr string } @@ -940,31 +927,87 @@ var _ = Describe("Server", func() { }) Context("ConfigureTLSConfig", func() { - var tlsConf *tls.Config - var ch *tls.ClientHelloInfo - - BeforeEach(func() { - tlsConf = &tls.Config{} - ch = &tls.ClientHelloInfo{} - }) - It("advertises v1 by default", func() { - tlsConf = ConfigureTLSConfig(tlsConf) - Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) - - config, err := tlsConf.GetConfigForClient(ch) - Expect(err).NotTo(HaveOccurred()) - Expect(config.NextProtos).To(Equal([]string{NextProtoH3})) + conf := ConfigureTLSConfig(testdata.GetTLSConfig()) + ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + Expect(err).ToNot(HaveOccurred()) + defer c.CloseWithError(0, "") + Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) }) It("advertises h3-29 for draft-29", func() { - tlsConf = ConfigureTLSConfig(tlsConf) - Expect(tlsConf.GetConfigForClient).NotTo(BeNil()) + conf := ConfigureTLSConfig(testdata.GetTLSConfig()) + ln, err := quic.ListenAddr("localhost:0", conf, &quic.Config{Versions: []quic.VersionNumber{quic.VersionDraft29}}) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3Draft29}}, nil) + Expect(err).ToNot(HaveOccurred()) + defer c.CloseWithError(0, "") + Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3Draft29)) + }) - ch.Conn = newMockConn(protocol.VersionDraft29) - config, err := tlsConf.GetConfigForClient(ch) - Expect(err).NotTo(HaveOccurred()) - Expect(config.NextProtos).To(Equal([]string{NextProtoH3Draft29})) + It("sets the GetConfigForClient callback if no tls.Config is given", func() { + var receivedConf *tls.Config + quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { + receivedConf = tlsConf + return nil, errors.New("listen err") + } + Expect(s.ListenAndServe()).To(HaveOccurred()) + Expect(receivedConf).ToNot(BeNil()) + }) + + It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { + tlsConf := &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + c := testdata.GetTLSConfig() + c.NextProtos = []string{"foo", "bar"} + return c, nil + }, + } + + ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + Expect(err).ToNot(HaveOccurred()) + defer c.CloseWithError(0, "") + Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + }) + + It("works if GetConfigForClient returns a nil tls.Config", func() { + tlsConf := testdata.GetTLSConfig() + tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil } + + ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + Expect(err).ToNot(HaveOccurred()) + defer c.CloseWithError(0, "") + Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + }) + + It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { + tlsClientConf := testdata.GetTLSConfig() + tlsClientConf.NextProtos = []string{"foo", "bar"} + tlsConf := &tls.Config{ + GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { + return tlsClientConf, nil + }, + } + + ln, err := quic.ListenAddr("localhost:0", ConfigureTLSConfig(tlsConf), &quic.Config{Versions: []quic.VersionNumber{quic.Version1}}) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + c, err := quic.DialAddr(ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) + Expect(err).ToNot(HaveOccurred()) + defer c.CloseWithError(0, "") + Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + // check that the original config was not modified + Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) }) }) @@ -1179,15 +1222,6 @@ var _ = Describe("Server", func() { Expect(s.Close()).To(Succeed()) }) - checkGetConfigForClientVersions := func(conf *tls.Config) { - c, err := conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.VersionDraft29)}) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, c.NextProtos).To(Equal([]string{NextProtoH3Draft29})) - c, err = conf.GetConfigForClient(&tls.ClientHelloInfo{Conn: newMockConn(protocol.Version1)}) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ExpectWithOffset(1, c.NextProtos).To(Equal([]string{NextProtoH3})) - } - It("uses the quic.Config to start the QUIC server", func() { conf := &quic.Config{HandshakeIdleTimeout: time.Nanosecond} var receivedConf *quic.Config @@ -1199,106 +1233,6 @@ var _ = Describe("Server", func() { Expect(s.ListenAndServe()).To(HaveOccurred()) Expect(receivedConf).To(Equal(conf)) }) - - It("sets the GetConfigForClient and replaces the ALPN token to the tls.Config, if the GetConfigForClient callback is not set", func() { - tlsConf := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - NextProtos: []string{"foo", "bar"}, - } - var receivedConf *tls.Config - quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = tlsConf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf.NextProtos).To(BeEmpty()) - Expect(receivedConf.ClientAuth).To(BeZero()) - // make sure the original tls.Config was not modified - Expect(tlsConf.NextProtos).To(Equal([]string{"foo", "bar"})) - // make sure that the config returned from the GetConfigForClient callback sets the fields of the original config - conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) - checkGetConfigForClientVersions(receivedConf) - }) - - It("sets the GetConfigForClient callback if no tls.Config is given", func() { - var receivedConf *tls.Config - quicListenAddr = func(addr string, tlsConf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = tlsConf - return nil, errors.New("listen err") - } - Expect(s.ListenAndServe()).To(HaveOccurred()) - Expect(receivedConf).ToNot(BeNil()) - checkGetConfigForClientVersions(receivedConf) - }) - - It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient", func() { - tlsConf := &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - return &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - NextProtos: []string{"foo", "bar"}, - }, nil - }, - } - - var receivedConf *tls.Config - quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = conf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - // check that the original config was not modified - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) - // check that the config returned by the GetConfigForClient callback uses the returned config - conf, err = receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.ClientAuth).To(Equal(tls.RequireAndVerifyClientCert)) - checkGetConfigForClientVersions(receivedConf) - }) - - It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { - tlsClientConf := &tls.Config{NextProtos: []string{"foo", "bar"}} - tlsConf := &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - return tlsClientConf, nil - }, - } - - var receivedConf *tls.Config - quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = conf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - // check that the original config was not modified - conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf.NextProtos).To(Equal([]string{"foo", "bar"})) - checkGetConfigForClientVersions(receivedConf) - }) - - It("works if GetConfigForClient returns a nil tls.Config", func() { - tlsConf := &tls.Config{GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { return nil, nil }} - - var receivedConf *tls.Config - quicListenAddr = func(addr string, conf *tls.Config, _ *quic.Config) (quic.EarlyListener, error) { - receivedConf = conf - return nil, errors.New("listen err") - } - s.TLSConfig = tlsConf - Expect(s.ListenAndServe()).To(HaveOccurred()) - conf, err := receivedConf.GetConfigForClient(&tls.ClientHelloInfo{}) - Expect(err).ToNot(HaveOccurred()) - Expect(conf).ToNot(BeNil()) - checkGetConfigForClientVersions(receivedConf) - }) }) It("closes gracefully", func() { diff --git a/interface.go b/interface.go index e55f258e..3eebe4a6 100644 --- a/interface.go +++ b/interface.go @@ -57,6 +57,10 @@ var ConnectionTracingKey = connTracingCtxKey{} type connTracingCtxKey struct{} +// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the +// context returned by tls.Config.ClientHelloInfo.Context. +var QUICVersionContextKey = handshake.QUICVersionContextKey + // Stream is the interface implemented by QUIC streams // In addition to the errors listed on the Connection, // calls to stream functions can return a StreamError if the stream is canceled. diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index ec14868c..310a34c0 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -2,6 +2,7 @@ package handshake import ( "bytes" + "context" "crypto/tls" "errors" "fmt" @@ -20,6 +21,10 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) +type quicVersionContextKey struct{} + +var QUICVersionContextKey = &quicVersionContextKey{} + // TLS unexpected_message alert const alertUnexpectedMessage uint8 = 10 @@ -64,30 +69,25 @@ const clientSessionStateRevision = 3 type conn struct { localAddr, remoteAddr net.Addr - version protocol.VersionNumber -} - -var _ ConnWithVersion = &conn{} - -func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion { - return &conn{ - localAddr: local, - remoteAddr: remote, - version: version, - } } var _ net.Conn = &conn{} -func (c *conn) Read([]byte) (int, error) { return 0, nil } -func (c *conn) Write([]byte) (int, error) { return 0, nil } -func (c *conn) Close() error { return nil } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) SetReadDeadline(time.Time) error { return nil } -func (c *conn) SetWriteDeadline(time.Time) error { return nil } -func (c *conn) SetDeadline(time.Time) error { return nil } -func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version } +func newConn(local, remote net.Addr) net.Conn { + return &conn{ + localAddr: local, + remoteAddr: remote, + } +} + +func (c *conn) Read([]byte) (int, error) { return 0, nil } +func (c *conn) Write([]byte) (int, error) { return 0, nil } +func (c *conn) Close() error { return nil } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) SetReadDeadline(time.Time) error { return nil } +func (c *conn) SetWriteDeadline(time.Time) error { return nil } +func (c *conn) SetDeadline(time.Time) error { return nil } type cryptoSetup struct { tlsConf *tls.Config @@ -183,7 +183,7 @@ func NewCryptoSetupClient( protocol.PerspectiveClient, version, ) - cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs, clientHelloWritten } @@ -218,7 +218,7 @@ func NewCryptoSetupServer( version, ) cs.allow0RTT = allow0RTT - cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf) + cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) return cs } @@ -307,7 +307,7 @@ func (h *cryptoSetup) RunHandshake() { handshakeErrChan := make(chan error, 1) go func() { defer close(h.handshakeDone) - if err := h.conn.Handshake(); err != nil { + if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil { handshakeErrChan <- err return } diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index e7baea90..f80b6e0e 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -3,7 +3,6 @@ package handshake import ( "errors" "io" - "net" "time" "github.com/quic-go/quic-go/internal/protocol" @@ -93,10 +92,3 @@ type CryptoSetup interface { Get0RTTSealer() (LongHeaderSealer, error) Get1RTTSealer() (ShortHeaderSealer, error) } - -// ConnWithVersion is the connection used in the ClientHelloInfo. -// It can be used to determine the QUIC version in use. -type ConnWithVersion interface { - net.Conn - GetQUICVersion() protocol.VersionNumber -}