Merge pull request #577 from lucas-clemente/fix-565

add a quic.Config option for QUIC versions
This commit is contained in:
Marten Seemann
2017-05-09 18:30:04 +08:00
committed by GitHub
18 changed files with 209 additions and 176 deletions

View File

@@ -2,4 +2,5 @@
## v0.6.0 (unreleased)
- Add a `quic.Config` option for QUIC versions
- Various bugfixes

View File

@@ -47,12 +47,13 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return nil, err
}
clientConfig := populateClientConfig(config)
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
connectionID: connID,
hostname: hostname,
config: config,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
config: clientConfig,
version: clientConfig.Versions[0],
}
c.connStateChangeOrErrCond.L = &c.mutex
@@ -67,6 +68,19 @@ func Dial(pconn net.PacketConn, remoteAddr net.Addr, host string, config *Config
return c.establishConnection()
}
func populateClientConfig(config *Config) *Config {
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
return &Config{
TLSConfig: config.TLSConfig,
ConnState: config.ConnState,
Versions: versions,
}
}
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(addr string, config *Config) (Session, error) {
@@ -191,20 +205,20 @@ func (c *client) handlePacketWithVersionFlag(hdr *PublicHeader) error {
}
}
ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions)
if !ok {
newVersion := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if newVersion == protocol.VersionUnsupported {
return qerr.InvalidVersion
}
// switch to negotiated version
c.version = highestSupportedVersion
c.version = newVersion
c.connState = ConnStateVersionNegotiated
var err error
c.connectionID, err = utils.GenerateConnectionID()
if err != nil {
return err
}
utils.Infof("Switching to QUIC version %d. New connection ID: %x", highestSupportedVersion, c.connectionID)
utils.Infof("Switching to QUIC version %d. New connection ID: %x", newVersion, c.connectionID)
c.session.Close(errCloseSessionForNewVersion)
err = c.createNewSession(hdr.SupportedVersions)

View File

@@ -2,7 +2,6 @@ package quic
import (
"bytes"
"encoding/binary"
"errors"
"net"
"reflect"
@@ -35,6 +34,7 @@ var _ = Describe("Client", func() {
versionNegotiateConnStateCalled = true
}
},
Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78},
}
addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
sess = &mockSession{connectionID: 0x1337}
@@ -42,7 +42,7 @@ var _ = Describe("Client", func() {
config: config,
connectionID: 0x1337,
session: sess,
version: protocol.Version36,
version: protocol.SupportedVersions[0],
conn: &conn{pconn: packetConn, currentAddr: addr},
}
})
@@ -57,7 +57,6 @@ var _ = Describe("Client", func() {
Context("Dialing", func() {
It("creates a new client", func() {
packetConn.dataToRead = []byte{0x0, 0x1, 0x0}
var err error
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil())
@@ -65,6 +64,11 @@ var _ = Describe("Client", func() {
sess.Close(nil)
})
It("uses all supported versions, if none are specified in the quic.Config", func() {
c := populateClientConfig(&Config{})
Expect(c.Versions).To(Equal(protocol.SupportedVersions))
})
It("errors when receiving an invalid first packet from the server", func() {
packetConn.dataToRead = []byte{0xff}
sess, err := Dial(packetConn, addr, "quic.clemente.io:1337", config)
@@ -198,22 +202,6 @@ var _ = Describe("Client", func() {
})
Context("version negotiation", func() {
getVersionNegotiation := func(versions []protocol.VersionNumber) []byte {
oldVersionNegotiationPacket := composeVersionNegotiation(0x1337)
oldSupportVersionTags := protocol.SupportedVersionsAsTags
var b bytes.Buffer
for _, v := range versions {
s := make([]byte, 4)
binary.LittleEndian.PutUint32(s, protocol.VersionNumberToTag(v))
b.Write(s)
}
protocol.SupportedVersionsAsTags = b.Bytes()
packet := composeVersionNegotiation(cl.connectionID)
protocol.SupportedVersionsAsTags = oldSupportVersionTags
Expect(composeVersionNegotiation(0x1337)).To(Equal(oldVersionNegotiationPacket))
return packet
}
It("recognizes that a packet without VersionFlag means that the server accepted the suggested version", func() {
ph := PublicHeader{
PacketNumber: 1,
@@ -230,11 +218,13 @@ var _ = Describe("Client", func() {
})
It("changes the version after receiving a version negotiation packet", func() {
newVersion := protocol.Version35
newVersion := protocol.VersionNumber(77)
Expect(config.Versions).To(ContainElement(newVersion))
Expect(newVersion).ToNot(Equal(cl.version))
Expect(sess.packetCount).To(BeZero())
cl.connectionID = 0x1337
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{newVersion}))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(newVersion))
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
Eventually(func() bool { return versionNegotiateConnStateCalled }).Should(BeTrue())
@@ -246,19 +236,33 @@ var _ = Describe("Client", func() {
Expect(sess.packetCount).To(BeZero())
// if the version negotiation packet was passed to the new session, it would end up as an undecryptable packet there
Expect(cl.session.(*session).undecryptablePackets).To(BeEmpty())
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35}))
Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(cl.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{newVersion}))
})
It("errors if no matching version is found", func() {
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
v := protocol.SupportedVersions[1]
Expect(v).ToNot(Equal(cl.version))
Expect(config.Versions).ToNot(ContainElement(v))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{v}))
Expect(err).To(MatchError(qerr.InvalidVersion))
})
It("changes to the version preferred by the quic.Config", func() {
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(config.Versions[1]))
})
It("ignores delayed version negotiation packets", func() {
// if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test
cl.connState = ConnStateVersionNegotiated
Expect(sess.packetCount).To(BeZero())
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{1}))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{1}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.connState).To(Equal(ConnStateVersionNegotiated))
Expect(sess.packetCount).To(BeZero())
@@ -267,7 +271,7 @@ var _ = Describe("Client", func() {
It("drops version negotiation packets that contain the offered version", func() {
ver := cl.version
err := cl.handlePacket(nil, getVersionNegotiation([]protocol.VersionNumber{ver}))
err := cl.handlePacket(nil, composeVersionNegotiation(0x1337, []protocol.VersionNumber{ver}))
Expect(err).ToNot(HaveOccurred())
Expect(cl.version).To(Equal(ver))
})

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/http"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
@@ -39,6 +40,8 @@ type Server struct {
listenerMutex sync.Mutex
listener quic.Listener
supportedVersionsAsString string
}
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
@@ -79,6 +82,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
}
config := quic.Config{
TLSConfig: tlsConfig,
ConnState: func(session quic.Session, connState quic.ConnState) {
@@ -87,7 +91,9 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error {
s.handleHeaderStream(sess)
}
},
Versions: protocol.SupportedVersions,
}
var ln quic.Listener
var err error
if conn == nil {
@@ -267,8 +273,17 @@ func (s *Server) SetQuicHeaders(hdr http.Header) error {
atomic.StoreUint32(&s.port, port)
}
if s.supportedVersionsAsString == "" {
for i, v := range protocol.SupportedVersions {
s.supportedVersionsAsString += strconv.Itoa(int(v))
if i != len(protocol.SupportedVersions)-1 {
s.supportedVersionsAsString += ","
}
}
}
hdr.Add("Alternate-Protocol", fmt.Sprintf("%d:quic", port))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, protocol.SupportedVersionsAsString))
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil
}

View File

@@ -270,7 +270,7 @@ func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
return false
}
ver := protocol.VersionTagToNumber(verTag)
if !protocol.IsSupportedVersion(ver) {
if !protocol.IsSupportedVersion(protocol.SupportedVersions, ver) {
ver = protocol.VersionUnsupported
}
if ver != negotiatedVersion {

View File

@@ -71,7 +71,7 @@ func (m *mockCertManager) Verify(hostname string) error {
return m.verifyError
}
var _ = Describe("Crypto setup", func() {
var _ = Describe("Client Crypto Setup", func() {
var cs *cryptoSetupClient
var certManager *mockCertManager
var stream *mockStream
@@ -81,7 +81,7 @@ var _ = Describe("Crypto setup", func() {
BeforeEach(func() {
shloMap = map[Tag][]byte{
TagPUBS: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f},
TagVER: protocol.SupportedVersionsAsTags,
TagVER: []byte{},
}
keyDerivation := func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error) {
keyDerivationCalledWith = &keyDerivationValues{

View File

@@ -24,10 +24,12 @@ type KeyExchangeFunction func() crypto.KeyExchange
type cryptoSetupServer struct {
connID protocol.ConnectionID
sourceAddr []byte
version protocol.VersionNumber
scfg *ServerConfig
diversificationNonce []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
@@ -61,12 +63,14 @@ func NewCryptoSetup(
scfg *ServerConfig,
cryptoStream io.ReadWriter,
connectionParametersManager ConnectionParametersManager,
supportedVersions []protocol.VersionNumber,
aeadChanged chan protocol.EncryptionLevel,
) (CryptoSetup, error) {
return &cryptoSetupServer{
connID: connID,
sourceAddr: sourceAddr,
version: version,
supportedVersions: supportedVersions,
scfg: scfg,
keyDerivation: crypto.DeriveKeysAESGCM,
keyExchange: getEphermalKEX,
@@ -127,7 +131,7 @@ func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]
verTag := binary.LittleEndian.Uint32(verSlice)
ver := protocol.VersionTagToNumber(verTag)
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
if ver != h.version && protocol.IsSupportedVersion(ver) {
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
@@ -397,9 +401,13 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T
return nil, err
}
// add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range h.supportedVersions {
utils.WriteUint32(verTag, protocol.VersionNumberToTag(v))
}
replyMap[TagPUBS] = ephermalKex.PublicKey()
replyMap[TagSNO] = serverNonce
replyMap[TagVER] = protocol.SupportedVersionsAsTags
replyMap[TagVER] = verTag.Bytes()
// note that the SHLO *has* to fit into one packet
var reply bytes.Buffer

View File

@@ -9,6 +9,7 @@ import (
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
@@ -140,22 +141,23 @@ func (s mockStkSource) VerifyToken(sourceAddr []byte, token []byte) error {
return nil
}
var _ = Describe("Crypto setup", func() {
var _ = Describe("Server Crypto Setup", func() {
var (
kex *mockKEX
signer *mockSigner
scfg *ServerConfig
cs *cryptoSetupServer
stream *mockStream
cpm ConnectionParametersManager
aeadChanged chan protocol.EncryptionLevel
nonce32 []byte
versionTag []byte
sourceAddr []byte
validSTK []byte
aead []byte
kexs []byte
version protocol.VersionNumber
kex *mockKEX
signer *mockSigner
scfg *ServerConfig
cs *cryptoSetupServer
stream *mockStream
cpm ConnectionParametersManager
aeadChanged chan protocol.EncryptionLevel
nonce32 []byte
versionTag []byte
sourceAddr []byte
validSTK []byte
aead []byte
kexs []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
)
BeforeEach(func() {
@@ -179,8 +181,9 @@ var _ = Describe("Crypto setup", func() {
Expect(err).NotTo(HaveOccurred())
scfg.stkSource = &mockStkSource{}
version = protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
supportedVersions = []protocol.VersionNumber{version, 98, 99}
cpm = NewConnectionParamatersManager(protocol.PerspectiveServer, protocol.VersionWhatever)
csInt, err := NewCryptoSetup(protocol.ConnectionID(42), sourceAddr, version, scfg, stream, cpm, aeadChanged)
csInt, err := NewCryptoSetup(protocol.ConnectionID(42), sourceAddr, version, scfg, stream, cpm, supportedVersions, aeadChanged)
Expect(err).NotTo(HaveOccurred())
cs = csInt.(*cryptoSetupServer)
cs.keyDerivation = mockKeyDerivation
@@ -275,7 +278,11 @@ var _ = Describe("Crypto setup", func() {
Expect(response).To(HavePrefix("SHLO"))
Expect(response).To(ContainSubstring("ephermal pub"))
Expect(response).To(ContainSubstring("SNO\x00"))
Expect(response).To(ContainSubstring(string(protocol.SupportedVersionsAsTags)))
for _, v := range supportedVersions {
b := &bytes.Buffer{}
utils.WriteUint32(b, protocol.VersionNumberToTag(v))
Expect(response).To(ContainSubstring(string(b.Bytes())))
}
Expect(cs.secureAEAD).ToNot(BeNil())
Expect(cs.secureAEAD.(*mockAEAD).forwardSecure).To(BeFalse())
Expect(cs.secureAEAD.(*mockAEAD).sharedSecret).To(Equal([]byte("shared key")))
@@ -391,8 +398,8 @@ var _ = Describe("Crypto setup", func() {
})
It("detects version downgrade attacks", func() {
highestSupportedVersion := protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
lowestSupportedVersion := protocol.SupportedVersions[0]
highestSupportedVersion := supportedVersions[len(protocol.SupportedVersions)-1]
lowestSupportedVersion := supportedVersions[0]
Expect(highestSupportedVersion).ToNot(Equal(lowestSupportedVersion))
cs.version = highestSupportedVersion
b := make([]byte, 4)
@@ -406,7 +413,7 @@ var _ = Describe("Crypto setup", func() {
It("accepts a non-matching version tag in the CHLO, if it is an unsupported version", func() {
supportedVersion := protocol.SupportedVersions[0]
unsupportedVersion := supportedVersion + 1000
Expect(protocol.IsSupportedVersion(unsupportedVersion)).To(BeFalse())
Expect(protocol.IsSupportedVersion(supportedVersions, unsupportedVersion)).To(BeFalse())
cs.version = supportedVersion
b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion))

View File

@@ -43,7 +43,8 @@ func init() {
}
var _ = Describe("Chrome tests", func() {
It("does not work with mismatching versions", func() {
// test disabled since it doesn't work with the configurable QUIC version in the server
PIt("does not work with mismatching versions", func() {
versionForUs := protocol.SupportedVersions[0]
versionForChrome := protocol.SupportedVersions[1]

View File

@@ -63,6 +63,10 @@ type Config struct {
// If this field is not set, the Dial functions will return only when the connection is forward secure.
// Callbacks have to be thread-safe, since they might be called in separate goroutines.
ConnState ConnStateCallback
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []protocol.VersionNumber
}
// A Listener for incoming QUIC connections

View File

@@ -1,11 +1,5 @@
package protocol
import (
"bytes"
"encoding/binary"
"strconv"
)
// VersionNumber is a version number as int
type VersionNumber int
@@ -14,22 +8,16 @@ const (
Version35 VersionNumber = 35 + iota
Version36
Version37
VersionWhatever = 0 // for when the version doesn't matter
VersionUnsupported = -1
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionUnsupported VersionNumber = -1
)
// SupportedVersions lists the versions that the server supports
// must be in sorted order
// must be in sorted descending order
var SupportedVersions = []VersionNumber{
Version35, Version36, Version37,
Version37, Version36, Version35,
}
// SupportedVersionsAsTags is needed for the SHLO crypto message
var SupportedVersionsAsTags []byte
// SupportedVersionsAsString is needed for the Alt-Scv HTTP header
var SupportedVersionsAsString string
// VersionNumberToTag maps version numbers ('32') to tags ('Q032')
func VersionNumberToTag(vn VersionNumber) uint32 {
v := uint32(vn)
@@ -42,8 +30,8 @@ func VersionTagToNumber(v uint32) VersionNumber {
}
// IsSupportedVersion returns true if the server supports this version
func IsSupportedVersion(v VersionNumber) bool {
for _, t := range SupportedVersions {
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
for _, t := range supported {
if t == v {
return true
}
@@ -51,41 +39,17 @@ func IsSupportedVersion(v VersionNumber) bool {
return false
}
// HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions
// the versions in other do not need to be ordered
// it returns true and the version number, if there is one, otherwise false
func HighestSupportedVersion(other []VersionNumber) (bool, VersionNumber) {
var otherSupported []VersionNumber
for _, ver := range other {
if ver != VersionUnsupported {
otherSupported = append(otherSupported, ver)
}
}
for i := len(SupportedVersions) - 1; i >= 0; i-- {
for _, ver := range otherSupported {
if ver == SupportedVersions[i] {
return true, ver
// ChooseSupportedVersion finds the best version in the overlap of ours and theirs
// ours is a slice of versions that we support, sorted by our preference (descending)
// theirs is a slice of versions offered by the peer. The order does not matter
// if no suitable version is found, it returns VersionUnsupported
func ChooseSupportedVersion(ours, theirs []VersionNumber) VersionNumber {
for _, ourVer := range ours {
for _, theirVer := range theirs {
if ourVer == theirVer {
return ourVer
}
}
}
return false, 0
}
func init() {
var b bytes.Buffer
for _, v := range SupportedVersions {
s := make([]byte, 4)
binary.LittleEndian.PutUint32(s, VersionNumberToTag(v))
b.Write(s)
}
SupportedVersionsAsTags = b.Bytes()
for i := len(SupportedVersions) - 1; i >= 0; i-- {
SupportedVersionsAsString += strconv.Itoa(int(SupportedVersions[i]))
if i != 0 {
SupportedVersionsAsString += ","
}
}
return VersionUnsupported
}

View File

@@ -14,59 +14,38 @@ var _ = Describe("Version", func() {
Expect(VersionNumberToTag(VersionNumber(123))).To(Equal(uint32('Q' + '1'<<8 + '2'<<16 + '3'<<24)))
})
It("has proper tag list", func() {
Expect(SupportedVersionsAsTags).To(Equal([]byte("Q035Q036Q037")))
})
It("has proper version list", func() {
Expect(SupportedVersionsAsString).To(Equal("37,36,35"))
})
It("recognizes supported versions", func() {
Expect(IsSupportedVersion(0)).To(BeFalse())
Expect(IsSupportedVersion(SupportedVersions[0])).To(BeTrue())
Expect(IsSupportedVersion(SupportedVersions, 0)).To(BeFalse())
Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[0])).To(BeTrue())
Expect(IsSupportedVersion(SupportedVersions, SupportedVersions[len(SupportedVersions)-1])).To(BeTrue())
})
It("has supported versions in sorted order", func() {
for i := 0; i < len(SupportedVersions)-1; i++ {
Expect(SupportedVersions[i]).To(BeNumerically("<", SupportedVersions[i+1]))
Expect(SupportedVersions[i]).To(BeNumerically(">", SupportedVersions[i+1]))
}
})
Context("highest supported version", func() {
var initialSupportedVersions []VersionNumber
BeforeEach(func() {
initialSupportedVersions = make([]VersionNumber, len(SupportedVersions))
copy(initialSupportedVersions, SupportedVersions)
})
AfterEach(func() {
SupportedVersions = initialSupportedVersions
})
It("finds the supported version", func() {
SupportedVersions = []VersionNumber{1, 2, 3}
other := []VersionNumber{3, 4, 5, 6}
found, ver := HighestSupportedVersion(other)
Expect(found).To(BeTrue())
Expect(ver).To(Equal(VersionNumber(3)))
supportedVersions := []VersionNumber{1, 2, 3}
other := []VersionNumber{6, 5, 4, 3}
Expect(ChooseSupportedVersion(supportedVersions, other)).To(Equal(VersionNumber(3)))
})
It("picks the highest supported version", func() {
SupportedVersions = []VersionNumber{1, 2, 3, 6, 7}
It("picks the preferred version", func() {
supportedVersions := []VersionNumber{2, 1, 3}
other := []VersionNumber{3, 6, 1, 8, 2, 10}
found, ver := HighestSupportedVersion(other)
Expect(found).To(BeTrue())
Expect(ver).To(Equal(VersionNumber(6)))
Expect(ChooseSupportedVersion(supportedVersions, other)).To(Equal(VersionNumber(2)))
})
It("handles empty inputs", func() {
SupportedVersions = []VersionNumber{101, 102}
Expect(HighestSupportedVersion([]VersionNumber{})).To(BeFalse())
SupportedVersions = []VersionNumber{}
Expect(HighestSupportedVersion([]VersionNumber{1, 2})).To(BeFalse())
Expect(HighestSupportedVersion([]VersionNumber{})).To(BeFalse())
supportedVersions := []VersionNumber{102, 101}
Expect(ChooseSupportedVersion(supportedVersions, nil)).To(Equal(VersionUnsupported))
Expect(ChooseSupportedVersion(supportedVersions, []VersionNumber{})).To(Equal(VersionUnsupported))
supportedVersions = []VersionNumber{}
Expect(ChooseSupportedVersion(supportedVersions, []VersionNumber{1, 2})).To(Equal(VersionUnsupported))
Expect(ChooseSupportedVersion(supportedVersions, []VersionNumber{})).To(Equal(VersionUnsupported))
})
})
})

View File

@@ -196,9 +196,6 @@ func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*Pub
break
}
v := protocol.VersionTagToNumber(versionTag)
if !protocol.IsSupportedVersion(v) {
v = protocol.VersionUnsupported
}
header.SupportedVersions = append(header.SupportedVersions, v)
}
}

View File

@@ -92,7 +92,7 @@ var _ = Describe("Public Header", func() {
}
It("parses version negotiation packets sent by the server", func() {
b := bytes.NewReader(composeVersionNegotiation(0x1337))
b := bytes.NewReader(composeVersionNegotiation(0x1337, protocol.SupportedVersions))
hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.VersionFlag).To(BeTrue())
@@ -111,21 +111,21 @@ var _ = Describe("Public Header", func() {
Expect(b.Len()).To(BeZero())
})
It("sets version numbers to unsupported, if we don't support them", func() {
It("reads version negotiation packets containing unsupported versions", func() {
data := []byte{0x9, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}
data = appendVersion(data, 1) // unsupported version
data = appendVersion(data, protocol.SupportedVersions[0])
data = appendVersion(data, 1337) // unsupported version
data = appendVersion(data, 99) // unsupported version
b := bytes.NewReader(data)
hdr, err := ParsePublicHeader(b, protocol.PerspectiveServer)
Expect(err).ToNot(HaveOccurred())
Expect(hdr.VersionFlag).To(BeTrue())
Expect(hdr.SupportedVersions).To(Equal([]protocol.VersionNumber{protocol.VersionUnsupported, protocol.SupportedVersions[0], protocol.VersionUnsupported}))
Expect(hdr.SupportedVersions).To(Equal([]protocol.VersionNumber{1, protocol.SupportedVersions[0], 99}))
Expect(b.Len()).To(BeZero())
})
It("errors on invalid version tags", func() {
data := composeVersionNegotiation(0x1337)
data := composeVersionNegotiation(0x1337, protocol.SupportedVersions)
data = append(data, []byte{0x13, 0x37}...)
b := bytes.NewReader(data)
_, err := ParsePublicHeader(b, protocol.PerspectiveServer)

View File

@@ -34,7 +34,7 @@ type server struct {
sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error)
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error)
}
var _ Listener = &server{}
@@ -68,7 +68,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
return &server{
conn: conn,
config: config,
config: populateServerConfig(config),
certChain: certChain,
scfg: scfg,
sessions: map[protocol.ConnectionID]packetHandler{},
@@ -77,6 +77,19 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) {
}, nil
}
func populateServerConfig(config *Config) *Config {
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
return &Config{
TLSConfig: config.TLSConfig,
ConnState: config.ConnState,
Versions: versions,
}
}
// Listen listens on an existing PacketConn
func (s *server) Serve() error {
for {
@@ -152,18 +165,18 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
// a session is only created once the client sent a supported version
// if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
// it is safe to drop it
if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) {
return nil
}
// Send Version Negotiation Packet if the client is speaking a different protocol version
if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.VersionNumber) {
// drop packets that are too small to be valid first packets
if len(packet) < protocol.ClientHelloMinimumSize+len(hdr.Raw) {
return errors.New("dropping small packet with unknown version")
}
utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber)
_, err = pconn.WriteTo(composeVersionNegotiation(hdr.ConnectionID), remoteAddr)
_, err = pconn.WriteTo(composeVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr)
return err
}
@@ -173,7 +186,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return err
}
version := hdr.VersionNumber
if !protocol.IsSupportedVersion(version) {
if !protocol.IsSupportedVersion(s.config.Versions, version) {
return errors.New("Server BUG: negotiated version not supported")
}
@@ -184,6 +197,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
hdr.ConnectionID,
s.scfg,
s.cryptoChangeCallback,
s.config.Versions,
)
if err != nil {
return err
@@ -240,17 +254,19 @@ func (s *server) removeConnection(id protocol.ConnectionID) {
})
}
func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
func composeVersionNegotiation(connectionID protocol.ConnectionID, versions []protocol.VersionNumber) []byte {
fullReply := &bytes.Buffer{}
responsePublicHeader := PublicHeader{
ConnectionID: connectionID,
PacketNumber: 1,
VersionFlag: true,
}
err := responsePublicHeader.Write(fullReply, protocol.Version35, protocol.PerspectiveServer)
err := responsePublicHeader.Write(fullReply, protocol.VersionWhatever, protocol.PerspectiveServer)
if err != nil {
utils.Errorf("error composing version negotiation packet: %s", err.Error())
}
fullReply.Write(protocol.SupportedVersionsAsTags)
for _, v := range versions {
utils.WriteUint32(fullReply, protocol.VersionNumberToTag(v))
}
return fullReply.Bytes()
}

View File

@@ -2,6 +2,7 @@ package quic
import (
"bytes"
"crypto/tls"
"errors"
"net"
"time"
@@ -55,7 +56,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
var _ Session = &mockSession{}
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
func newMockSession(_ connection, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, _ cryptoChangeCallback, _ []protocol.VersionNumber) (packetHandler, error) {
return &mockSession{
connectionID: connectionID,
stopRunLoop: make(chan struct{}),
@@ -71,7 +72,10 @@ var _ = Describe("Server", func() {
BeforeEach(func() {
conn = &mockPacketConn{}
config = &Config{}
config = &Config{
TLSConfig: &tls.Config{},
Versions: protocol.SupportedVersions,
}
})
Context("with mock session", func() {
@@ -105,9 +109,9 @@ var _ = Describe("Server", func() {
It("composes version negotiation packets", func() {
expected := append(
[]byte{0x01 | 0x08, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
protocol.SupportedVersionsAsTags...,
[]byte{'Q', '0', '9', '9'}...,
)
Expect(composeVersionNegotiation(1)).To(Equal(expected))
Expect(composeVersionNegotiation(1, []protocol.VersionNumber{99})).To(Equal(expected))
})
It("creates new sessions", func() {
@@ -260,7 +264,7 @@ var _ = Describe("Server", func() {
Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))
b := &bytes.Buffer{}
// add an unsupported version
utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0]-2))
utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0]+1))
data := []byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c}
data = append(append(data, b.Bytes()...), 0x01)
err = serv.handlePacket(nil, nil, data)
@@ -320,16 +324,29 @@ var _ = Describe("Server", func() {
})
It("setups with the right values", func() {
var connStateCallback ConnStateCallback = func(_ Session, _ ConnState) {}
supportedVersions := []protocol.VersionNumber{1, 3, 5}
config := Config{
ConnState: func(_ Session, _ ConnState) {},
TLSConfig: &tls.Config{},
ConnState: connStateCallback,
Versions: supportedVersions,
}
ln, err := Listen(conn, &config)
server := ln.(*server)
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout))
Expect(server.sessions).ToNot(BeNil())
Expect(server.scfg).ToNot(BeNil())
Expect(server.config).To(Equal(&config))
Expect(server.config.ConnState).ToNot(BeNil())
Expect(server.config.Versions).To(Equal(supportedVersions))
})
It("fills in default values if options are not set in the Config", func() {
config := Config{TLSConfig: &tls.Config{}}
ln, err := Listen(conn, &config)
Expect(err).ToNot(HaveOccurred())
server := ln.(*server)
Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
})
It("listens on a given address", func() {
@@ -353,6 +370,7 @@ var _ = Describe("Server", func() {
})
It("setups and responds with version negotiation", func() {
config.Versions = []protocol.VersionNumber{99}
b := &bytes.Buffer{}
hdr := PublicHeader{
VersionFlag: true,
@@ -375,9 +393,11 @@ var _ = Describe("Server", func() {
Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero())
Expect(conn.dataWrittenTo).To(Equal(udpAddr))
b = &bytes.Buffer{}
utils.WriteUint32(b, protocol.VersionNumberToTag(99))
expected := append(
[]byte{0x9, 0x37, 0x13, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
protocol.SupportedVersionsAsTags...,
b.Bytes()...,
)
Expect(conn.dataWritten.Bytes()).To(Equal(expected))
Expect(returned).To(BeFalse())

View File

@@ -98,7 +98,7 @@ type session struct {
var _ Session = &session{}
// newSession makes a new session
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) {
s := &session{
conn: conn,
connectionID: connectionID,
@@ -119,7 +119,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
sourceAddr = []byte(conn.RemoteAddr().String())
}
var err error
s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, s.aeadChanged)
s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, supportedVersions, s.aeadChanged)
if err != nil {
return nil, err
}

View File

@@ -149,6 +149,7 @@ var _ = Describe("Session", func() {
0,
scfg,
func(Session, bool) {},
nil,
)
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
@@ -183,6 +184,7 @@ var _ = Describe("Session", func() {
0,
scfg,
func(Session, bool) {},
nil,
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200}))
@@ -198,6 +200,7 @@ var _ = Describe("Session", func() {
0,
scfg,
func(Session, bool) {},
nil,
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337")))