forked from quic-go/quic-go
fix the server's 0-RTT rejection logic when using GetConfigForClient (#4550)
This commit is contained in:
21
internal/qtls/conn.go
Normal file
21
internal/qtls/conn.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package qtls
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
localAddr, remoteAddr net.Addr
|
||||
}
|
||||
|
||||
var _ net.Conn = &conn{}
|
||||
|
||||
func (c *conn) Read([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Write([]byte) (int, error) { return 0, nil }
|
||||
func (c *conn) Close() error { return nil }
|
||||
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *conn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *conn) SetDeadline(time.Time) error { return nil }
|
||||
@@ -4,20 +4,23 @@ import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
)
|
||||
|
||||
func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
|
||||
conf := qconf.TLSConfig
|
||||
|
||||
func SetupConfigForServer(
|
||||
conf *tls.Config,
|
||||
localAddr, remoteAddr net.Addr,
|
||||
getData func() []byte,
|
||||
handleSessionTicket func([]byte, bool) bool,
|
||||
) *tls.Config {
|
||||
// Workaround for https://github.com/golang/go/issues/60506.
|
||||
// This initializes the session tickets _before_ cloning the config.
|
||||
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
|
||||
|
||||
conf = conf.Clone()
|
||||
conf.MinVersion = tls.VersionTLS13
|
||||
qconf.TLSConfig = conf
|
||||
|
||||
// add callbacks to save transport parameters into the session ticket
|
||||
origWrapSession := conf.WrapSession
|
||||
@@ -58,6 +61,29 @@ func SetupConfigForServer(qconf *tls.QUICConfig, getData func() []byte, handleSe
|
||||
|
||||
return state, nil
|
||||
}
|
||||
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
|
||||
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
|
||||
// that allows the caller to get the local and the remote address.
|
||||
if conf.GetConfigForClient != nil {
|
||||
gcfc := conf.GetConfigForClient
|
||||
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
c, err := gcfc(info)
|
||||
if c != nil {
|
||||
// We're returning a tls.Config here, so we need to apply this recursively.
|
||||
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
}
|
||||
if conf.GetCertificate != nil {
|
||||
gc := conf.GetCertificate
|
||||
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
|
||||
return gc(info)
|
||||
}
|
||||
}
|
||||
return conf
|
||||
}
|
||||
|
||||
func SetupConfigForClient(
|
||||
|
||||
@@ -2,6 +2,8 @@ package qtls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
|
||||
@@ -41,13 +43,86 @@ var _ = Describe("interface go crypto/tls", func() {
|
||||
})
|
||||
|
||||
Context("setting up a tls.Config for the server", func() {
|
||||
var (
|
||||
local = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 42}
|
||||
remote = &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
|
||||
)
|
||||
|
||||
It("sets the minimum TLS version to TLS 1.3", func() {
|
||||
orig := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
conf := &tls.QUICConfig{TLSConfig: orig}
|
||||
SetupConfigForServer(conf, nil, nil)
|
||||
Expect(conf.TLSConfig.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
conf := SetupConfigForServer(orig, local, remote, nil, nil)
|
||||
Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
// check that the original config wasn't modified
|
||||
Expect(orig.MinVersion).To(BeEquivalentTo(tls.VersionTLS12))
|
||||
})
|
||||
|
||||
It("wraps GetCertificate", func() {
|
||||
var localAddr, remoteAddr net.Addr
|
||||
tlsConf := &tls.Config{
|
||||
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
localAddr = info.Conn.LocalAddr()
|
||||
remoteAddr = info.Conn.RemoteAddr()
|
||||
return &tls.Certificate{}, nil
|
||||
},
|
||||
}
|
||||
conf := SetupConfigForServer(tlsConf, local, remote, nil, nil)
|
||||
_, err := conf.GetCertificate(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(localAddr).To(Equal(local))
|
||||
Expect(remoteAddr).To(Equal(remote))
|
||||
})
|
||||
|
||||
It("wraps GetConfigForClient", func() {
|
||||
var localAddr, remoteAddr net.Addr
|
||||
tlsConf := SetupConfigForServer(
|
||||
&tls.Config{
|
||||
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
localAddr = info.Conn.LocalAddr()
|
||||
remoteAddr = info.Conn.RemoteAddr()
|
||||
return &tls.Config{}, nil
|
||||
},
|
||||
},
|
||||
local,
|
||||
remote,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(localAddr).To(Equal(local))
|
||||
Expect(remoteAddr).To(Equal(remote))
|
||||
Expect(conf).ToNot(BeNil())
|
||||
Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
})
|
||||
|
||||
It("wraps GetConfigForClient, recursively", func() {
|
||||
var localAddr, remoteAddr net.Addr
|
||||
tlsConf := &tls.Config{}
|
||||
var innerConf *tls.Config
|
||||
getCert := func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { //nolint:unparam
|
||||
localAddr = info.Conn.LocalAddr()
|
||||
remoteAddr = info.Conn.RemoteAddr()
|
||||
return &tls.Certificate{}, nil
|
||||
}
|
||||
tlsConf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
innerConf = tlsConf.Clone()
|
||||
// set the MaxVersion, so we can check that quic-go doesn't overwrite the user's config
|
||||
innerConf.MaxVersion = tls.VersionTLS12
|
||||
innerConf.GetCertificate = getCert
|
||||
return innerConf, nil
|
||||
}
|
||||
tlsConf = SetupConfigForServer(tlsConf, local, remote, nil, nil)
|
||||
conf, err := tlsConf.GetConfigForClient(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(conf).ToNot(BeNil())
|
||||
Expect(conf.MinVersion).To(BeEquivalentTo(tls.VersionTLS13))
|
||||
_, err = conf.GetCertificate(&tls.ClientHelloInfo{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(localAddr).To(Equal(local))
|
||||
Expect(remoteAddr).To(Equal(remote))
|
||||
// make sure that the tls.Config returned by GetConfigForClient isn't modified
|
||||
Expect(reflect.ValueOf(innerConf.GetCertificate).Pointer() == reflect.ValueOf(getCert).Pointer()).To(BeTrue())
|
||||
Expect(innerConf.MaxVersion).To(BeEquivalentTo(tls.VersionTLS12))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user