diff --git a/crypto/source_address_token.go b/crypto/source_address_token.go index dea1d2b2..fefc3f56 100644 --- a/crypto/source_address_token.go +++ b/crypto/source_address_token.go @@ -18,6 +18,14 @@ import ( "golang.org/x/crypto/hkdf" ) +// StkSource is used to create and verify source address tokens +type StkSource interface { + // NewToken creates a new token for a given IP address + NewToken(ip net.IP) ([]byte, error) + // VerifyToken verifies if a token matches a given IP address and is not outdated + VerifyToken(ip net.IP, data []byte) error +} + type sourceAddressToken struct { ip net.IP // unix timestamp in seconds @@ -51,7 +59,8 @@ const stkKeySize = 16 // at 16 :) const stkNonceSize = 16 -func newStkSource(secret []byte) (*stkSource, error) { +// NewStkSource creates a source for source address tokens +func NewStkSource(secret []byte) (StkSource, error) { key, err := deriveKey(secret) if err != nil { return nil, err diff --git a/crypto/source_address_token_test.go b/crypto/source_address_token_test.go index 496eb1d6..2a545fa3 100644 --- a/crypto/source_address_token_test.go +++ b/crypto/source_address_token_test.go @@ -52,7 +52,8 @@ var _ = Describe("Source Address Tokens", func() { Expect(ip6).NotTo(BeEmpty()) secret = []byte("TESTING") - source, err = newStkSource(secret) + sourceI, err := NewStkSource(secret) + source = sourceI.(*stkSource) Expect(err).NotTo(HaveOccurred()) }) diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index 309ab74f..cb56a9fc 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "io" + "net" "sync" "github.com/lucas-clemente/quic-go/crypto" @@ -21,6 +22,7 @@ type KeyExchangeFunction func() (crypto.KeyExchange, error) // The CryptoSetup handles all things crypto for the Session type CryptoSetup struct { connID protocol.ConnectionID + ip net.IP version protocol.VersionNumber scfg *ServerConfig nonce []byte @@ -45,7 +47,15 @@ type CryptoSetup struct { var _ crypto.AEAD = &CryptoSetup{} // NewCryptoSetup creates a new CryptoSetup instance -func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream, connectionParametersManager *ConnectionParametersManager, aeadChanged chan struct{}) (*CryptoSetup, error) { +func NewCryptoSetup( + connID protocol.ConnectionID, + ip net.IP, + version protocol.VersionNumber, + scfg *ServerConfig, + cryptoStream utils.Stream, + connectionParametersManager *ConnectionParametersManager, + aeadChanged chan struct{}, +) (*CryptoSetup, error) { nonce := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, err @@ -56,6 +66,7 @@ func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber } return &CryptoSetup{ connID: connID, + ip: ip, version: version, scfg: scfg, nonce: nonce, @@ -81,41 +92,53 @@ func (h *CryptoSetup) HandleCryptoStream() error { } chloData := cachingReader.Get() - utils.Infof("Got crypto message:\n%s", printHandshakeMessage(cryptoData)) + utils.Infof("Got CHLO:\n%s", printHandshakeMessage(cryptoData)) - sniSlice, ok := cryptoData[TagSNI] - if !ok { - return qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") - } - sni := string(sniSlice) - if sni == "" { - return qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") - } - - var reply []byte - if !h.isInchoateCHLO(cryptoData) { - // We have a CHLO with a proper server config ID, do a 0-RTT handshake - reply, err = h.handleCHLO(sni, chloData, cryptoData) - if err != nil { - return err - } - _, err = h.cryptoStream.Write(reply) - if err != nil { - return err - } - return nil - } - - // We have an inchoate or non-matching CHLO, we now send a rejection - reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData) + done, err := h.handleMessage(chloData, cryptoData) if err != nil { return err } + if done { + return nil + } + } +} + +func (h *CryptoSetup) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) { + sniSlice, ok := cryptoData[TagSNI] + if !ok { + return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") + } + sni := string(sniSlice) + if sni == "" { + return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required") + } + + var reply []byte + var err error + if !h.isInchoateCHLO(cryptoData) { + // We have a CHLO with a proper server config ID, do a 0-RTT handshake + reply, err = h.handleCHLO(sni, chloData, cryptoData) + if err != nil { + return false, err + } _, err = h.cryptoStream.Write(reply) if err != nil { - return err + return false, err } + return true, nil } + + // We have an inchoate or non-matching CHLO, we now send a rejection + reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData) + if err != nil { + return false, err + } + _, err = h.cryptoStream.Write(reply) + if err != nil { + return false, err + } + return false, nil } // Open a message @@ -165,6 +188,10 @@ func (h *CryptoSetup) isInchoateCHLO(cryptoData map[Tag][]byte) bool { if !ok || !bytes.Equal(h.scfg.ID, scid) { return true } + if err := h.scfg.stkSource.VerifyToken(h.ip, cryptoData[TagSTK]); err != nil { + utils.Infof("STK invalid: %s", err.Error()) + return false + } return false } @@ -191,11 +218,17 @@ func (h *CryptoSetup) handleInchoateCHLO(sni string, data []byte, cryptoData map return nil, err } + token, err := h.scfg.stkSource.NewToken(h.ip) + if err != nil { + return nil, err + } + var serverReply bytes.Buffer WriteHandshakeMessage(&serverReply, TagREJ, map[Tag][]byte{ TagSCFG: h.scfg.Get(), TagCERT: certCompressed, TagPROF: proof, + TagSTK: token, }) return serverReply.Bytes(), nil } @@ -285,3 +318,11 @@ func (h *CryptoSetup) DiversificationNonce() []byte { } return h.diversificationNonce } + +func (h *CryptoSetup) verifyOrCreateSTK(token []byte) ([]byte, error) { + err := h.scfg.stkSource.VerifyToken(h.ip, token) + if err != nil { + return h.scfg.stkSource.NewToken(h.ip) + } + return token, nil +} diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 94d80024..26078df0 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -3,6 +3,7 @@ package handshake import ( "bytes" "errors" + "net" "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/protocol" @@ -102,6 +103,26 @@ func (s *mockStream) Close() error { panic("not implemente func (mockStream) CloseRemote(offset protocol.ByteCount) { panic("not implemented") } func (s mockStream) StreamID() protocol.StreamID { panic("not implemented") } +type mockStkSource struct{} + +func (mockStkSource) NewToken(ip net.IP) ([]byte, error) { + return append([]byte("token "), ip...), nil +} + +func (mockStkSource) VerifyToken(ip net.IP, token []byte) error { + split := bytes.Split(token, []byte(" ")) + if len(split) != 2 { + return errors.New("stk required") + } + if !bytes.Equal(split[0], []byte("token")) { + return errors.New("no prefix match") + } + if !bytes.Equal(split[1], ip) { + return errors.New("ip wrong") + } + return nil +} + var _ = Describe("Crypto setup", func() { var ( kex *mockKEX @@ -112,22 +133,28 @@ var _ = Describe("Crypto setup", func() { cpm *ConnectionParametersManager aeadChanged chan struct{} nonce32 []byte + ip net.IP + validSTK []byte ) BeforeEach(func() { + var err error + ip = net.ParseIP("1.2.3.4") + validSTK, err = mockStkSource{}.NewToken(ip) + Expect(err).NotTo(HaveOccurred()) nonce32 = make([]byte, 32) expectedInitialNonceLen = 32 expectedFSNonceLen = 64 - var err error aeadChanged = make(chan struct{}, 1) stream = &mockStream{} kex = &mockKEX{} signer = &mockSigner{} scfg, err = NewServerConfig(kex, signer) Expect(err).NotTo(HaveOccurred()) + scfg.stkSource = &mockStkSource{} v := protocol.SupportedVersions[len(protocol.SupportedVersions)-1] cpm = NewConnectionParamatersManager() - cs, err = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream, cpm, aeadChanged) + cs, err = NewCryptoSetup(protocol.ConnectionID(42), ip, v, scfg, stream, cpm, aeadChanged) Expect(err).NotTo(HaveOccurred()) cs.keyDerivation = mockKeyDerivation cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil } @@ -207,12 +234,14 @@ var _ = Describe("Crypto setup", func() { It("handles long handshake", func() { WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ TagSNI: []byte("quic.clemente.io"), + TagSTK: validSTK, TagPAD: bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), }) WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ TagSCID: scfg.ID, TagSNI: []byte("quic.clemente.io"), TagNONC: nonce32, + TagSTK: validSTK, }) err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) @@ -226,6 +255,7 @@ var _ = Describe("Crypto setup", func() { TagSCID: scfg.ID, TagSNI: []byte("quic.clemente.io"), TagNONC: nonce32, + TagSTK: validSTK, }) err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) @@ -249,7 +279,9 @@ var _ = Describe("Crypto setup", func() { }) It("errors without SNI", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{}) + WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ + TagSTK: validSTK, + }) err := cs.HandleCryptoStream() Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) }) @@ -338,4 +370,34 @@ var _ = Describe("Crypto setup", func() { }) }) }) + + Context("STK verification and creation", func() { + It("requires STK", func() { + done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ + TagSNI: []byte("foo"), + }) + Expect(done).To(BeFalse()) + Expect(err).To(BeNil()) + Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK))) + }) + + It("works with proper STK", func() { + done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ + TagSTK: validSTK, + TagSNI: []byte("foo"), + }) + Expect(done).To(BeFalse()) + Expect(err).To(BeNil()) + }) + + It("errors if IP does not match", func() { + done, err := cs.handleMessage(bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), map[Tag][]byte{ + TagSNI: []byte("foo"), + TagSTK: []byte("token \x04\x03\x03\x01"), + }) + Expect(done).To(BeFalse()) + Expect(err).To(BeNil()) + Expect(stream.dataWritten.Bytes()).To(ContainSubstring(string(validSTK))) + }) + }) }) diff --git a/handshake/server_config.go b/handshake/server_config.go index 38623a51..c3de7bd1 100644 --- a/handshake/server_config.go +++ b/handshake/server_config.go @@ -10,9 +10,10 @@ import ( // ServerConfig is a server config type ServerConfig struct { - kex crypto.KeyExchange - signer crypto.Signer - ID []byte + kex crypto.KeyExchange + signer crypto.Signer + ID []byte + stkSource crypto.StkSource } // NewServerConfig creates a new server config @@ -22,10 +23,21 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi if err != nil { return nil, err } + + stkSecret := make([]byte, 32) + if _, err = io.ReadFull(rand.Reader, stkSecret); err != nil { + return nil, err + } + stkSource, err := crypto.NewStkSource(stkSecret) + if err != nil { + return nil, err + } + return &ServerConfig{ - kex: kex, - signer: signer, - ID: id, + kex: kex, + signer: signer, + ID: id, + stkSource: stkSource, }, nil } diff --git a/session.go b/session.go index 0465016d..81852547 100644 --- a/session.go +++ b/session.go @@ -111,7 +111,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol cryptoStream, _ := session.OpenStream(1) var err error - session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) + session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, conn.IP(), v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index f2e152e6..8157750d 100644 --- a/session_test.go +++ b/session_test.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "io" + "net" "runtime" "sync/atomic" "time" @@ -31,6 +32,7 @@ func (m *mockConnection) write(p []byte) error { } func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {} +func (*mockConnection) IP() net.IP { return nil } var _ = Describe("Session", func() { var ( diff --git a/udp_conn.go b/udp_conn.go index 1e2938ee..ea1cc7ed 100644 --- a/udp_conn.go +++ b/udp_conn.go @@ -5,6 +5,7 @@ import "net" type connection interface { write([]byte) error setCurrentRemoteAddr(interface{}) + IP() net.IP } type udpConn struct { @@ -22,3 +23,7 @@ func (c *udpConn) write(p []byte) error { func (c *udpConn) setCurrentRemoteAddr(addr interface{}) { c.currentAddr = addr.(*net.UDPAddr) } + +func (c *udpConn) IP() net.IP { + return c.currentAddr.IP +}