From 79cd3235f87f6fece60959c2b89f626d341f3ff5 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 4 Mar 2019 19:00:33 +0900 Subject: [PATCH 1/9] log the stateless reset token in the transport parameters --- internal/handshake/transport_parameter_test.go | 3 ++- internal/handshake/transport_parameters.go | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 9e983a27..7e3d9b46 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -24,8 +24,9 @@ var _ = Describe("Transport Parameters", func() { IdleTimeout: 42 * time.Second, OriginalConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, AckDelayExponent: 14, + StatelessResetToken: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, } - Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s, AckDelayExponent: 14}")) + Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s, AckDelayExponent: 14, StatelessResetToken: 0xdeadbeefcafe}")) }) getRandomValue := func() uint64 { diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index bdb0d917..69354f33 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -231,5 +231,10 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) { // String returns a string representation, intended for logging. func (p *TransportParameters) String() string { - return fmt.Sprintf("&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s, AckDelayExponent: %d}", p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout, p.AckDelayExponent) + logString := "&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s, AckDelayExponent: %d" + if len(p.StatelessResetToken) > 0 { // the client never sends a stateless reset token + logString += ", StatelessResetToken: %#x" + } + logString += "}" + return fmt.Sprintf(logString, p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout, p.AckDelayExponent, p.StatelessResetToken) } From 759cd289c9b2fb967e26279ccf797af939c77cb9 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Mar 2019 10:08:55 +0900 Subject: [PATCH 2/9] accept stateless resets for connection with zero-length connection IDs --- packet_handler_map.go | 34 +++++++++++++++++++++++----------- packet_handler_map_test.go | 15 +++++++++++++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/packet_handler_map.go b/packet_handler_map.go index f3e1d019..e89baf49 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -183,6 +183,10 @@ func (h *packetHandlerMap) handlePacket( h.mutex.RLock() defer h.mutex.RUnlock() + if isStatelessReset := h.maybeHandleStatelessReset(data); isStatelessReset { + return + } + handlerEntry, handlerFound := h.handlers[string(connID)] p := &receivedPacket{ @@ -195,17 +199,7 @@ func (h *packetHandlerMap) handlePacket( handlerEntry.handler.handlePacket(p) return } - // No session found. - // This might be a stateless reset. - if data[0]&0x80 == 0 { // stateless resets are always short header packets - if len(p.data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], p.data[len(p.data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) - return - } - } + if data[0]&0x80 == 0 { // TODO(#943): send a stateless reset h.logger.Debugf("received a short header packet with an unexpected connection ID %s", connID) return @@ -216,3 +210,21 @@ func (h *packetHandlerMap) handlePacket( } h.server.handlePacket(p) } + +func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { + // stateless resets are always short header packets + if data[0]&0x80 != 0 { + return false + } + if len(data) < protocol.MinStatelessResetSize { + return false + } + + var token [16]byte + copy(token[:], data[len(data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + return true + } + return false +} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 2f1e03a8..00a2d40e 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -195,6 +195,21 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) + It("handles stateless resets for 0-length connection IDs", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddWithResetToken(protocol.ConnectionID{}, packetHandler, token) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- packet + Eventually(destroyed).Should(BeClosed()) + }) + It("deletes reset tokens when the session is retired", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} From 733dcb75eb43828c04e78244c53fcc76f12fc9d4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Mar 2019 13:48:01 +0900 Subject: [PATCH 3/9] rename the methods of the sessionRunner interface --- client.go | 3 +- client_test.go | 4 +- mock_session_runner_test.go | 36 +++++++++--------- server.go | 17 ++++----- server_test.go | 8 ++-- session.go | 8 ++-- session_test.go | 74 ++++++++++++++++++------------------- 7 files changed, 73 insertions(+), 77 deletions(-) diff --git a/client.go b/client.go index 0ae2564a..dc70677e 100644 --- a/client.go +++ b/client.go @@ -358,9 +358,8 @@ func (c *client) createNewTLSSession(version protocol.VersionNumber) error { c.mutex.Lock() defer c.mutex.Unlock() runner := &runner{ + packetHandlerManager: c.packetHandlers, onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) }, - retireConnectionIDImpl: c.packetHandlers.Retire, - removeConnectionIDImpl: c.packetHandlers.Remove, } sess, err := newClientSession( c.conn, diff --git a/client_test.go b/client_test.go index 3a63ef53..4a857d06 100644 --- a/client_test.go +++ b/client_test.go @@ -204,7 +204,7 @@ var _ = Describe("Client", func() { ) (quicSession, error) { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Do(func() { close(run) }) - runner.onHandshakeComplete(sess) + runner.OnHandshakeComplete(sess) return sess, nil } s, err := Dial( @@ -325,7 +325,7 @@ var _ = Describe("Client", func() { return sess, nil } sess.EXPECT().run().Do(func() { - runner.retireConnectionID(connID) + runner.Retire(connID) }) _, err := DialContext( diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index d1efee9c..ab88320f 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -34,38 +34,38 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { return m.recorder } -// onHandshakeComplete mocks base method -func (m *MockSessionRunner) onHandshakeComplete(arg0 Session) { +// OnHandshakeComplete mocks base method +func (m *MockSessionRunner) OnHandshakeComplete(arg0 Session) { m.ctrl.T.Helper() - m.ctrl.Call(m, "onHandshakeComplete", arg0) + m.ctrl.Call(m, "OnHandshakeComplete", arg0) } -// onHandshakeComplete indicates an expected call of onHandshakeComplete -func (mr *MockSessionRunnerMockRecorder) onHandshakeComplete(arg0 interface{}) *gomock.Call { +// OnHandshakeComplete indicates an expected call of OnHandshakeComplete +func (mr *MockSessionRunnerMockRecorder) OnHandshakeComplete(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "onHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).onHandshakeComplete), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnHandshakeComplete", reflect.TypeOf((*MockSessionRunner)(nil).OnHandshakeComplete), arg0) } -// removeConnectionID mocks base method -func (m *MockSessionRunner) removeConnectionID(arg0 protocol.ConnectionID) { +// Remove mocks base method +func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() - m.ctrl.Call(m, "removeConnectionID", arg0) + m.ctrl.Call(m, "Remove", arg0) } -// removeConnectionID indicates an expected call of removeConnectionID -func (mr *MockSessionRunnerMockRecorder) removeConnectionID(arg0 interface{}) *gomock.Call { +// Remove indicates an expected call of Remove +func (mr *MockSessionRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "removeConnectionID", reflect.TypeOf((*MockSessionRunner)(nil).removeConnectionID), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionRunner)(nil).Remove), arg0) } -// retireConnectionID mocks base method -func (m *MockSessionRunner) retireConnectionID(arg0 protocol.ConnectionID) { +// Retire mocks base method +func (m *MockSessionRunner) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() - m.ctrl.Call(m, "retireConnectionID", arg0) + m.ctrl.Call(m, "Retire", arg0) } -// retireConnectionID indicates an expected call of retireConnectionID -func (mr *MockSessionRunnerMockRecorder) retireConnectionID(arg0 interface{}) *gomock.Call { +// Retire indicates an expected call of Retire +func (mr *MockSessionRunnerMockRecorder) Retire(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "retireConnectionID", reflect.TypeOf((*MockSessionRunner)(nil).retireConnectionID), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Retire", reflect.TypeOf((*MockSessionRunner)(nil).Retire), arg0) } diff --git a/server.go b/server.go index b4f4b0ac..59c534d7 100644 --- a/server.go +++ b/server.go @@ -52,20 +52,18 @@ type quicSession interface { } type sessionRunner interface { - onHandshakeComplete(Session) - retireConnectionID(protocol.ConnectionID) - removeConnectionID(protocol.ConnectionID) + OnHandshakeComplete(Session) + Retire(protocol.ConnectionID) + Remove(protocol.ConnectionID) } type runner struct { + packetHandlerManager + onHandshakeCompleteImpl func(Session) - retireConnectionIDImpl func(protocol.ConnectionID) - removeConnectionIDImpl func(protocol.ConnectionID) } -func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) } -func (r *runner) retireConnectionID(c protocol.ConnectionID) { r.retireConnectionIDImpl(c) } -func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) } +func (r *runner) OnHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) } var _ sessionRunner = &runner{} @@ -169,6 +167,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, func (s *server) setup() error { s.sessionRunner = &runner{ + packetHandlerManager: s.sessionHandler, onHandshakeCompleteImpl: func(sess Session) { go func() { atomic.AddInt32(&s.sessionQueueLen, 1) @@ -181,8 +180,6 @@ func (s *server) setup() error { } }() }, - retireConnectionIDImpl: s.sessionHandler.Retire, - removeConnectionIDImpl: s.sessionHandler.Remove, } cookieGenerator, err := handshake.NewCookieGenerator() if err != nil { diff --git a/server_test.go b/server_test.go index ca8a8801..7313f32b 100644 --- a/server_test.go +++ b/server_test.go @@ -346,7 +346,7 @@ var _ = Describe("Server", func() { sess.EXPECT().handlePacket(p) sess.EXPECT().run() sess.EXPECT().Context().Return(context.Background()) - runner.onHandshakeComplete(sess) + runner.OnHandshakeComplete(sess) return sess, nil } @@ -403,7 +403,7 @@ var _ = Describe("Server", func() { sess.EXPECT().handlePacket(p) sess.EXPECT().run() sess.EXPECT().Context().Return(ctx) - runner.onHandshakeComplete(sess) + runner.OnHandshakeComplete(sess) close(sessionCreated) return sess, nil } @@ -489,7 +489,7 @@ var _ = Describe("Server", func() { ) (quicSession, error) { go func() { <-completeHandshake - runner.onHandshakeComplete(sess) + runner.OnHandshakeComplete(sess) }() sess.EXPECT().run().Do(func() {}) sess.EXPECT().Context().Return(context.Background()) @@ -521,7 +521,7 @@ var _ = Describe("Server", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().run().Do(func() {}) sess.EXPECT().Context().Return(context.Background()) - runner.onHandshakeComplete(sess) + runner.OnHandshakeComplete(sess) done <- struct{}{} return sess, nil } diff --git a/session.go b/session.go index 6766c6d3..70ccc649 100644 --- a/session.go +++ b/session.go @@ -473,7 +473,7 @@ func (s *session) idleTimeoutStartTime() time.Time { func (s *session) handleHandshakeComplete() { s.handshakeComplete = true s.handshakeCompleteChan = nil // prevent this case from ever being selected again - s.sessionRunner.onHandshakeComplete(s) + s.sessionRunner.OnHandshakeComplete(s) // The client completes the handshake first (after sending the CFIN). // We need to make sure they learn about the peer completing the handshake, @@ -835,7 +835,7 @@ func (s *session) closeLocal(e error) { } else { s.logger.Errorf("Closing session with error: %s", e) } - s.sessionRunner.retireConnectionID(s.srcConnID) + s.sessionRunner.Retire(s.srcConnID) s.closeChan <- closeError{err: e, sendClose: true, remote: false} }) } @@ -848,7 +848,7 @@ func (s *session) destroy(e error) { } else { s.logger.Errorf("Destroying session with error: %s", e) } - s.sessionRunner.removeConnectionID(s.srcConnID) + s.sessionRunner.Remove(s.srcConnID) s.closeChan <- closeError{err: e, sendClose: false, remote: false} }) } @@ -864,7 +864,7 @@ func (s *session) closeForRecreating() protocol.PacketNumber { func (s *session) closeRemote(e error) { s.closeOnce.Do(func() { s.logger.Errorf("Peer closed session with error: %s", e) - s.sessionRunner.removeConnectionID(s.srcConnID) + s.sessionRunner.Remove(s.srcConnID) s.closeChan <- closeError{err: e, remote: true} }) } diff --git a/session_test.go b/session_test.go index b31f63e5..f263cba0 100644 --- a/session_test.go +++ b/session_test.go @@ -322,7 +322,7 @@ var _ = Describe("Session", func() { It("handles CONNECTION_CLOSE frames", func() { testErr := qerr.Error(qerr.StreamLimitError, "foobar") streamManager.EXPECT().CloseWithError(testErr) - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()) cryptoSetup.EXPECT().Close() go func() { @@ -377,7 +377,7 @@ var _ = Describe("Session", func() { It("shuts down without error", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.NoError, "")) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{raw: []byte("connection close")}, nil) Expect(sess.Close()).To(Succeed()) @@ -389,7 +389,7 @@ var _ = Describe("Session", func() { It("only closes once", func() { streamManager.EXPECT().CloseWithError(qerr.Error(qerr.NoError, "")) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) Expect(sess.Close()).To(Succeed()) @@ -402,7 +402,7 @@ var _ = Describe("Session", func() { It("closes streams with proper error", func() { testErr := errors.New("test error") streamManager.EXPECT().CloseWithError(qerr.Error(0x1337, testErr.Error())) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sess.CloseWithError(0x1337, testErr) @@ -412,7 +412,7 @@ var _ = Describe("Session", func() { It("closes the session in order to recreate it", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()) cryptoSetup.EXPECT().Close() sess.closeForRecreating() Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent @@ -423,7 +423,7 @@ var _ = Describe("Session", func() { It("destroys the session", func() { testErr := errors.New("close") streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()) cryptoSetup.EXPECT().Close() sess.destroy(testErr) Eventually(areSessionsRunning).Should(BeFalse()) @@ -433,7 +433,7 @@ var _ = Describe("Session", func() { It("cancels the context when the run loop exists", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) returned := make(chan struct{}) @@ -451,7 +451,7 @@ var _ = Describe("Session", func() { It("retransmits the CONNECTION_CLOSE packet if packets are arriving late", func() { streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{raw: []byte("foobar")}, nil) sess.Close() @@ -547,7 +547,7 @@ var _ = Describe("Session", func() { cryptoSetup.EXPECT().RunHandshake().Do(func() { <-sess.Context().Done() }) sess.run() }() - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -574,7 +574,7 @@ var _ = Describe("Session", func() { Expect(err).To(MatchError("PROTOCOL_VIOLATION: empty packet")) close(done) }() - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, @@ -887,7 +887,7 @@ var _ = Describe("Session", func() { Consistently(mconn.written).Should(HaveLen(2)) // make the go routine return packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() Eventually(done).Should(BeClosed()) @@ -914,7 +914,7 @@ var _ = Describe("Session", func() { Consistently(mconn.written).Should(HaveLen(1)) // make the go routine return packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() Eventually(done).Should(BeClosed()) @@ -943,7 +943,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written, 2*pacingDelay).Should(HaveLen(2)) // make the go routine return packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() Eventually(done).Should(BeClosed()) @@ -969,7 +969,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written).Should(HaveLen(3)) // make the go routine return packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() Eventually(done).Should(BeClosed()) @@ -990,7 +990,7 @@ var _ = Describe("Session", func() { sess.scheduleSending() // no packet will get sent Consistently(mconn.written).ShouldNot(Receive()) // make the go routine return - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.Close() @@ -1018,7 +1018,7 @@ var _ = Describe("Session", func() { sess.scheduleSending() Eventually(mconn.written).Should(Receive()) // make the go routine return - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() @@ -1052,7 +1052,7 @@ var _ = Describe("Session", func() { Eventually(mconn.written).Should(Receive()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() @@ -1064,7 +1064,7 @@ var _ = Describe("Session", func() { It("closes when RunHandshake() errors", func() { testErr := errors.New("crypto setup error") streamManager.EXPECT().CloseWithError(qerr.Error(qerr.InternalError, testErr.Error())) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) go func() { @@ -1080,13 +1080,13 @@ var _ = Describe("Session", func() { packer.EXPECT().PackPacket().AnyTimes() go func() { defer GinkgoRecover() - sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()) + sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()) cryptoSetup.EXPECT().RunHandshake() sess.run() }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make sure the go routine returns - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() @@ -1097,7 +1097,7 @@ var _ = Describe("Session", func() { It("sends a forward-secure packet when the handshake completes", func() { done := make(chan struct{}) gomock.InOrder( - sessionRunner.EXPECT().onHandshakeComplete(gomock.Any()), + sessionRunner.EXPECT().OnHandshakeComplete(gomock.Any()), packer.EXPECT().PackPacket().DoAndReturn(func() (*packedPacket, error) { defer close(done) return &packedPacket{ @@ -1115,7 +1115,7 @@ var _ = Describe("Session", func() { Eventually(done).Should(BeClosed()) //make sure the go routine returns streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) @@ -1131,7 +1131,7 @@ var _ = Describe("Session", func() { close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) @@ -1149,7 +1149,7 @@ var _ = Describe("Session", func() { close(done) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() Expect(sess.CloseWithError(0x1337, testErr)).To(Succeed()) @@ -1166,7 +1166,7 @@ var _ = Describe("Session", func() { Expect(err.Error()).To(ContainSubstring("transport parameter")) }() streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.processTransportParameters([]byte("invalid")) @@ -1195,7 +1195,7 @@ var _ = Describe("Session", func() { sess.processTransportParameters(chtp.Marshal()) // make the go routine return streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.Close() @@ -1250,7 +1250,7 @@ var _ = Describe("Session", func() { }() Eventually(sent).Should(BeClosed()) // make the go routine return - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() @@ -1271,7 +1271,7 @@ var _ = Describe("Session", func() { }() Consistently(mconn.written).ShouldNot(Receive()) // make the go routine return - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() @@ -1292,7 +1292,7 @@ var _ = Describe("Session", func() { }() Consistently(mconn.written).ShouldNot(Receive()) // make the go routine return - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) streamManager.EXPECT().CloseWithError(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() @@ -1307,7 +1307,7 @@ var _ = Describe("Session", func() { }) It("times out due to no network activity", func() { - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()) sess.handshakeComplete = true sess.lastPacketReceivedTime = time.Now().Add(-time.Hour) done := make(chan struct{}) @@ -1327,7 +1327,7 @@ var _ = Describe("Session", func() { It("times out due to non-completed handshake", func() { sess.sessionCreationTime = time.Now().Add(-protocol.DefaultHandshakeTimeout).Add(-time.Second) - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()) cryptoSetup.EXPECT().Close() done := make(chan struct{}) go func() { @@ -1359,7 +1359,7 @@ var _ = Describe("Session", func() { }() Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1367,13 +1367,13 @@ var _ = Describe("Session", func() { It("closes the session due to the idle timeout after handshake", func() { packer.EXPECT().PackPacket().AnyTimes() - sessionRunner.EXPECT().removeConnectionID(gomock.Any()) + sessionRunner.EXPECT().Remove(gomock.Any()) cryptoSetup.EXPECT().Close() sess.config.IdleTimeout = 0 done := make(chan struct{}) go func() { defer GinkgoRecover() - sessionRunner.EXPECT().onHandshakeComplete(sess) + sessionRunner.EXPECT().OnHandshakeComplete(sess) cryptoSetup.EXPECT().RunHandshake() err := sess.run() nerr, ok := err.(net.Error) @@ -1398,7 +1398,7 @@ var _ = Describe("Session", func() { Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() sess.Close() Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1573,7 +1573,7 @@ var _ = Describe("Client Session", func() { }, []byte{0}))).To(BeTrue()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) cryptoSetup.EXPECT().Close() Expect(sess.Close()).To(Succeed()) Eventually(sess.Context().Done()).Should(BeClosed()) @@ -1639,7 +1639,7 @@ var _ = Describe("Client Session", func() { Expect(err.Error()).To(ContainSubstring("transport parameter")) }() // streamManager.EXPECT().CloseWithError(gomock.Any()) - sessionRunner.EXPECT().retireConnectionID(gomock.Any()) + sessionRunner.EXPECT().Retire(gomock.Any()) packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) cryptoSetup.EXPECT().Close() sess.processTransportParameters([]byte("invalid")) From dd8c590b13f03381d81a52b9c396601fad78da36 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Mar 2019 14:16:07 +0900 Subject: [PATCH 4/9] add methods to add and remove reset tokens to the packet handler map --- packet_handler_map.go | 52 +++++++++++++++++--------------------- packet_handler_map_test.go | 29 ++++----------------- 2 files changed, 28 insertions(+), 53 deletions(-) diff --git a/packet_handler_map.go b/packet_handler_map.go index e89baf49..0ae65b51 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -11,11 +11,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type packetHandlerEntry struct { - handler packetHandler - resetToken *[16]byte -} - // The packetHandlerMap stores packetHandlers, identified by connection ID. // It is used: // * by the server to store sessions @@ -26,7 +21,7 @@ type packetHandlerMap struct { conn net.PacketConn connIDLen int - handlers map[string] /* string(ConnectionID)*/ packetHandlerEntry + handlers map[string] /* string(ConnectionID)*/ packetHandler resetTokens map[[16]byte] /* stateless reset token */ packetHandler server unknownPacketHandler @@ -45,7 +40,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger conn: conn, connIDLen: connIDLen, listening: make(chan struct{}), - handlers: make(map[string]packetHandlerEntry), + handlers: make(map[string]packetHandler), resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, logger: logger, @@ -56,14 +51,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() - h.handlers[string(id)] = packetHandlerEntry{handler: handler} - h.mutex.Unlock() -} - -func (h *packetHandlerMap) AddWithResetToken(id protocol.ConnectionID, handler packetHandler, token [16]byte) { - h.mutex.Lock() - h.handlers[string(id)] = packetHandlerEntry{handler: handler, resetToken: &token} - h.resetTokens[token] = handler + h.handlers[string(id)] = handler h.mutex.Unlock() } @@ -73,12 +61,7 @@ func (h *packetHandlerMap) Remove(id protocol.ConnectionID) { func (h *packetHandlerMap) removeByConnectionIDAsString(id string) { h.mutex.Lock() - if handlerEntry, ok := h.handlers[id]; ok { - if token := handlerEntry.resetToken; token != nil { - delete(h.resetTokens, *token) - } - delete(h.handlers, id) - } + delete(h.handlers, id) h.mutex.Unlock() } @@ -92,6 +75,18 @@ func (h *packetHandlerMap) retireByConnectionIDAsString(id string) { }) } +func (h *packetHandlerMap) AddResetToken(token [16]byte, handler packetHandler) { + h.mutex.Lock() + h.resetTokens[token] = handler + h.mutex.Unlock() +} + +func (h *packetHandlerMap) RemoveResetToken(token [16]byte) { + h.mutex.Lock() + delete(h.resetTokens, token) + h.mutex.Unlock() +} + func (h *packetHandlerMap) SetServer(s unknownPacketHandler) { h.mutex.Lock() h.server = s @@ -102,8 +97,7 @@ func (h *packetHandlerMap) CloseServer() { h.mutex.Lock() h.server = nil var wg sync.WaitGroup - for id, handlerEntry := range h.handlers { - handler := handlerEntry.handler + for id, handler := range h.handlers { if handler.getPerspective() == protocol.PerspectiveServer { wg.Add(1) go func(id string, handler packetHandler) { @@ -136,12 +130,12 @@ func (h *packetHandlerMap) close(e error) error { h.closed = true var wg sync.WaitGroup - for _, handlerEntry := range h.handlers { + for _, handler := range h.handlers { wg.Add(1) - go func(handlerEntry packetHandlerEntry) { - handlerEntry.handler.destroy(e) + go func(handler packetHandler) { + handler.destroy(e) wg.Done() - }(handlerEntry) + }(handler) } if h.server != nil { @@ -187,7 +181,7 @@ func (h *packetHandlerMap) handlePacket( return } - handlerEntry, handlerFound := h.handlers[string(connID)] + handler, handlerFound := h.handlers[string(connID)] p := &receivedPacket{ remoteAddr: addr, @@ -196,7 +190,7 @@ func (h *packetHandlerMap) handlePacket( data: data, } if handlerFound { // existing session - handlerEntry.handler.handlePacket(p) + handler.handlePacket(p) return } if data[0]&0x80 == 0 { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 00a2d40e..1b078c1b 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -163,28 +163,10 @@ var _ = Describe("Packet Handler Map", func() { }) Context("stateless reset handling", func() { - It("handles packets for connections added with a reset token", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, packetHandler, token) - // first send a normal packet - handledPacket := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - cid, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(cid).To(Equal(connID)) - close(handledPacket) - }) - conn.dataToRead <- getPacket(connID) - Eventually(handledPacket).Should(BeClosed()) - }) - It("handles stateless resets", func() { packetHandler := NewMockPacketHandler(mockCtrl) - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, packetHandler, token) + handler.AddResetToken(token, packetHandler) packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) destroyed := make(chan struct{}) @@ -199,7 +181,7 @@ var _ = Describe("Packet Handler Map", func() { handler.connIDLen = 0 packetHandler := NewMockPacketHandler(mockCtrl) token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(protocol.ConnectionID{}, packetHandler, token) + handler.AddResetToken(token, packetHandler) packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) packet = append(packet, token[:]...) destroyed := make(chan struct{}) @@ -210,13 +192,12 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) - It("deletes reset tokens when the session is retired", func() { + It("deletes reset tokens", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, NewMockPacketHandler(mockCtrl), token) - handler.Retire(connID) - time.Sleep(scaleDuration(30 * time.Millisecond)) + handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) + handler.RemoveResetToken(token) handler.handlePacket(nil, nil, getPacket(connID)) // don't EXPECT any calls to handlePacket of the MockPacketHandler packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) From 6cf20310fe0215a96ba3b79b86e4d78948d2afeb Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 5 Mar 2019 15:23:43 +0900 Subject: [PATCH 5/9] use an array for the stateless reset token in the transport parameters --- internal/handshake/crypto_setup_test.go | 6 ++++-- internal/handshake/transport_parameter_test.go | 18 ++++++++++-------- internal/handshake/transport_parameters.go | 18 +++++++++--------- server.go | 7 ++++--- session.go | 2 +- session_test.go | 6 +++--- 6 files changed, 31 insertions(+), 26 deletions(-) diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index c2390fd6..9d52b801 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -271,6 +271,7 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(err).ToNot(HaveOccurred()) sChunkChan, sInitialStream, sHandshakeStream := initStreams() + var token [16]byte server, err := NewCryptoSetupServer( sInitialStream, sHandshakeStream, @@ -279,7 +280,7 @@ var _ = Describe("Crypto Setup TLS", func() { &EncryptedExtensionsTransportParameters{ NegotiatedVersion: protocol.VersionTLS, SupportedVersions: []protocol.VersionNumber{protocol.VersionTLS}, - Parameters: TransportParameters{StatelessResetToken: bytes.Repeat([]byte{42}, 16)}, + Parameters: TransportParameters{StatelessResetToken: &token}, }, func([]byte) {}, serverConf, @@ -359,9 +360,10 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(err).ToNot(HaveOccurred()) sChunkChan, sInitialStream, sHandshakeStream := initStreams() + var token [16]byte sTransportParameters := &TransportParameters{ IdleTimeout: 0x1337 * time.Second, - StatelessResetToken: bytes.Repeat([]byte{42}, 16), + StatelessResetToken: &token, } server, err := NewCryptoSetupServer( sInitialStream, diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 7e3d9b46..0d665417 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -24,9 +24,9 @@ var _ = Describe("Transport Parameters", func() { IdleTimeout: 42 * time.Second, OriginalConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, AckDelayExponent: 14, - StatelessResetToken: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe}, + StatelessResetToken: &[16]byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x00}, } - Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s, AckDelayExponent: 14, StatelessResetToken: 0xdeadbeefcafe}")) + Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s, AckDelayExponent: 14, StatelessResetToken: 0x112233445566778899aabbccddeeff00}")) }) getRandomValue := func() uint64 { @@ -36,6 +36,8 @@ var _ = Describe("Transport Parameters", func() { } It("marshals und unmarshals", func() { + var token [16]byte + rand.Read(token[:]) params := &TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), @@ -45,7 +47,7 @@ var _ = Describe("Transport Parameters", func() { MaxBidiStreams: getRandomValue(), MaxUniStreams: getRandomValue(), DisableMigration: true, - StatelessResetToken: bytes.Repeat([]byte{100}, 16), + StatelessResetToken: &token, OriginalConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, AckDelayExponent: 13, } @@ -68,9 +70,10 @@ var _ = Describe("Transport Parameters", func() { }) It("errors when the stateless_reset_token has the wrong length", func() { - params := &TransportParameters{StatelessResetToken: bytes.Repeat([]byte{100}, 15)} b := &bytes.Buffer{} - params.marshal(b) + utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID)) + utils.BigEndian.WriteUint16(b, 15) + b.Write(make([]byte, 15)) p := &TransportParameters{} Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)")) }) @@ -190,9 +193,8 @@ var _ = Describe("Transport Parameters", func() { }) It("errors if the client sent a stateless_reset_token", func() { - params := &TransportParameters{ - StatelessResetToken: make([]byte, 16), - } + var token [16]byte + params := &TransportParameters{StatelessResetToken: &token} b := &bytes.Buffer{} params.marshal(b) p := &TransportParameters{} diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index 69354f33..15cab96b 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -46,7 +46,7 @@ type TransportParameters struct { IdleTimeout time.Duration DisableMigration bool - StatelessResetToken []byte + StatelessResetToken *[16]byte OriginalConnectionID protocol.ConnectionID } @@ -94,9 +94,9 @@ func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective if paramLen != 16 { return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) } - b := make([]byte, 16) - r.Read(b) - p.StatelessResetToken = b + var token [16]byte + r.Read(token[:]) + p.StatelessResetToken = &token case originalConnectionIDParameterID: if sentBy == protocol.PerspectiveClient { return errors.New("client sent an original_connection_id") @@ -216,10 +216,10 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) { utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID)) utils.BigEndian.WriteUint16(b, 0) } - if len(p.StatelessResetToken) > 0 { + if p.StatelessResetToken != nil { utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID)) - utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes - b.Write(p.StatelessResetToken) + utils.BigEndian.WriteUint16(b, 16) + b.Write(p.StatelessResetToken[:]) } // original_connection_id if p.OriginalConnectionID.Len() > 0 { @@ -232,9 +232,9 @@ func (p *TransportParameters) marshal(b *bytes.Buffer) { // String returns a string representation, intended for logging. func (p *TransportParameters) String() string { logString := "&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s, AckDelayExponent: %d" - if len(p.StatelessResetToken) > 0 { // the client never sends a stateless reset token + if p.StatelessResetToken != nil { // the client never sends a stateless reset token logString += ", StatelessResetToken: %#x" } logString += "}" - return fmt.Sprintf(logString, p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout, p.AckDelayExponent, p.StatelessResetToken) + return fmt.Sprintf(logString, p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout, p.AckDelayExponent, *p.StatelessResetToken) } diff --git a/server.go b/server.go index 59c534d7..4f7b418b 100644 --- a/server.go +++ b/server.go @@ -430,6 +430,8 @@ func (s *server) createNewSession( srcConnID protocol.ConnectionID, version protocol.VersionNumber, ) (quicSession, error) { + // TODO(#855): generate a real token + token := [16]byte{42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42} params := &handshake.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, @@ -440,9 +442,8 @@ func (s *server) createNewSession( MaxUniStreams: uint64(s.config.MaxIncomingUniStreams), AckDelayExponent: protocol.AckDelayExponent, DisableMigration: true, - // TODO(#855): generate a real token - StatelessResetToken: bytes.Repeat([]byte{42}, 16), - OriginalConnectionID: origDestConnID, + StatelessResetToken: &token, + OriginalConnectionID: origDestConnID, } sess, err := s.newSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, diff --git a/session.go b/session.go index 70ccc649..bd8ba567 100644 --- a/session.go +++ b/session.go @@ -952,7 +952,7 @@ func (s *session) processTransportParametersForClient(data []byte) (*handshake.T params := &eetp.Parameters // check that the server sent a stateless reset token - if len(params.StatelessResetToken) == 0 { + if params.StatelessResetToken == nil { return nil, errors.New("server didn't send stateless_reset_token") } // check the Retry token diff --git a/session_test.go b/session_test.go index f263cba0..af5e6be2 100644 --- a/session_test.go +++ b/session_test.go @@ -1661,7 +1661,7 @@ var _ = Describe("Client Session", func() { SupportedVersions: []protocol.VersionNumber{sess.version}, Parameters: handshake.TransportParameters{ OriginalConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - StatelessResetToken: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, }, } _, err := sess.processTransportParametersForClient(eetp.Marshal()) @@ -1675,7 +1675,7 @@ var _ = Describe("Client Session", func() { SupportedVersions: []protocol.VersionNumber{sess.version}, Parameters: handshake.TransportParameters{ OriginalConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, - StatelessResetToken: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, }, } _, err := sess.processTransportParametersForClient(eetp.Marshal()) @@ -1687,7 +1687,7 @@ var _ = Describe("Client Session", func() { BeforeEach(func() { params = handshake.TransportParameters{ - StatelessResetToken: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, OriginalConnectionID: sess.origDestConnID, } }) From 94046cdb4b2aed7826373616c6bb1a5c708d820f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 6 Mar 2019 13:58:32 +0900 Subject: [PATCH 6/9] implement sending of stateless resets --- client.go | 3 ++- client_test.go | 22 +++++++++------- interface.go | 2 ++ mock_multiplexer_test.go | 8 +++--- mock_packet_handler_manager_test.go | 38 ++++++++++++++++++++++++++ mock_session_runner_test.go | 24 +++++++++++++++++ multiplexer.go | 27 ++++++++++++++----- multiplexer_test.go | 13 ++++++--- packet_handler_map.go | 41 ++++++++++++++++++++++++++--- packet_handler_map_test.go | 39 +++++++++++++++++++++------ server.go | 13 ++++++--- server_test.go | 12 +++++---- session.go | 3 +++ 13 files changed, 199 insertions(+), 46 deletions(-) diff --git a/client.go b/client.go index dc70677e..9ede391f 100644 --- a/client.go +++ b/client.go @@ -120,7 +120,7 @@ func dialContext( createdPacketConn bool, ) (Session, error) { config = populateClientConfig(config, createdPacketConn) - packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength) + packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength, config.StatelessResetKey) if err != nil { return nil, err } @@ -240,6 +240,7 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config { MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, KeepAlive: config.KeepAlive, + StatelessResetKey: config.StatelessResetKey, } } diff --git a/client_test.go b/client_test.go index 4a857d06..87bd1a9e 100644 --- a/client_test.go +++ b/client_test.go @@ -127,7 +127,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Close() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) remoteAddrChan := make(chan string, 1) newClientSession = func( @@ -157,7 +157,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) manager.EXPECT().Close() - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) hostnameChan := make(chan string, 1) newClientSession = func( @@ -186,7 +186,7 @@ var _ = Describe("Client", func() { It("returns after the handshake is complete", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) run := make(chan struct{}) newClientSession = func( @@ -222,7 +222,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs while waiting for the connection to become secure", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") newClientSession = func( @@ -256,7 +256,7 @@ var _ = Describe("Client", func() { It("closes the session when the context is canceled", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(gomock.Any(), gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) sessionRunning := make(chan struct{}) defer close(sessionRunning) @@ -304,7 +304,7 @@ var _ = Describe("Client", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) manager.EXPECT().Retire(connID) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) var runner sessionRunner sess := NewMockQuicSession(mockCtrl) @@ -345,7 +345,7 @@ var _ = Describe("Client", func() { } manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(gomock.Any(), gomock.Any(), gomock.Any()).Return(manager, nil) manager.EXPECT().Add(gomock.Any(), gomock.Any()) var conn connection @@ -401,6 +401,7 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, ConnectionIDLength: 13, + StatelessResetKey: []byte("foobar"), } c := populateClientConfig(config, false) Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute)) @@ -408,11 +409,12 @@ var _ = Describe("Client", func() { Expect(c.MaxIncomingStreams).To(Equal(1234)) Expect(c.MaxIncomingUniStreams).To(Equal(4321)) Expect(c.ConnectionIDLength).To(Equal(13)) + Expect(c.StatelessResetKey).To(Equal([]byte("foobar"))) }) It("errors when the Config contains an invalid version", func() { manager := NewMockPacketHandlerManager(mockCtrl) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) version := protocol.VersionNumber(0x1234) _, err := Dial(packetConn, nil, "localhost:1234", &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}}) @@ -456,7 +458,7 @@ var _ = Describe("Client", func() { It("creates new TLS sessions with the right parameters", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}} c := make(chan struct{}) @@ -508,7 +510,7 @@ var _ = Describe("Client", func() { It("returns an error that occurs during version negotiation", func() { manager := NewMockPacketHandlerManager(mockCtrl) manager.EXPECT().Add(connID, gomock.Any()) - mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any()).Return(manager, nil) + mockMultiplexer.EXPECT().AddConn(packetConn, gomock.Any(), gomock.Any()).Return(manager, nil) testErr := errors.New("early handshake error") newClientSession = func( diff --git a/interface.go b/interface.go index 0b80702f..4827c7d9 100644 --- a/interface.go +++ b/interface.go @@ -208,6 +208,8 @@ type Config struct { // If not set, it will default to 100. // If set to a negative value, it doesn't allow any unidirectional streams. MaxIncomingUniStreams int + // The StatelessResetKey is used to generate stateless reset tokens. + StatelessResetKey []byte // KeepAlive defines whether this peer will periodically send a packet to keep the connection alive. KeepAlive bool } diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go index 52731b5a..7ca5a254 100644 --- a/mock_multiplexer_test.go +++ b/mock_multiplexer_test.go @@ -35,18 +35,18 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { } // AddConn mocks base method -func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int) (packetHandlerManager, error) { +func (m *MockMultiplexer) AddConn(arg0 net.PacketConn, arg1 int, arg2 []byte) (packetHandlerManager, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddConn", arg0, arg1) + ret := m.ctrl.Call(m, "AddConn", arg0, arg1, arg2) ret0, _ := ret[0].(packetHandlerManager) ret1, _ := ret[1].(error) return ret0, ret1 } // AddConn indicates an expected call of AddConn -func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1 interface{}) *gomock.Call { +func (mr *MockMultiplexerMockRecorder) AddConn(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddConn", reflect.TypeOf((*MockMultiplexer)(nil).AddConn), arg0, arg1, arg2) } // RemoveConn mocks base method diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 257af2c7..3793d270 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -46,6 +46,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Add(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockPacketHandlerManager)(nil).Add), arg0, arg1) } +// AddResetToken mocks base method +func (m *MockPacketHandlerManager) AddResetToken(arg0 [16]byte, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddResetToken", arg0, arg1) +} + +// AddResetToken indicates an expected call of AddResetToken +func (mr *MockPacketHandlerManagerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).AddResetToken), arg0, arg1) +} + // Close mocks base method func (m *MockPacketHandlerManager) Close() error { m.ctrl.T.Helper() @@ -72,6 +84,20 @@ func (mr *MockPacketHandlerManagerMockRecorder) CloseServer() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseServer", reflect.TypeOf((*MockPacketHandlerManager)(nil).CloseServer)) } +// GetStatelessResetToken mocks base method +func (m *MockPacketHandlerManager) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) + ret0, _ := ret[0].([16]byte) + return ret0 +} + +// GetStatelessResetToken indicates an expected call of GetStatelessResetToken +func (mr *MockPacketHandlerManagerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).GetStatelessResetToken), arg0) +} + // Remove mocks base method func (m *MockPacketHandlerManager) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() @@ -84,6 +110,18 @@ func (mr *MockPacketHandlerManagerMockRecorder) Remove(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockPacketHandlerManager)(nil).Remove), arg0) } +// RemoveResetToken mocks base method +func (m *MockPacketHandlerManager) RemoveResetToken(arg0 [16]byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken +func (mr *MockPacketHandlerManagerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockPacketHandlerManager)(nil).RemoveResetToken), arg0) +} + // Retire mocks base method func (m *MockPacketHandlerManager) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index ab88320f..fad0e813 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -34,6 +34,18 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { return m.recorder } +// AddResetToken mocks base method +func (m *MockSessionRunner) AddResetToken(arg0 [16]byte, arg1 packetHandler) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddResetToken", arg0, arg1) +} + +// AddResetToken indicates an expected call of AddResetToken +func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) +} + // OnHandshakeComplete mocks base method func (m *MockSessionRunner) OnHandshakeComplete(arg0 Session) { m.ctrl.T.Helper() @@ -58,6 +70,18 @@ func (mr *MockSessionRunnerMockRecorder) Remove(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionRunner)(nil).Remove), arg0) } +// RemoveResetToken mocks base method +func (m *MockSessionRunner) RemoveResetToken(arg0 [16]byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RemoveResetToken", arg0) +} + +// RemoveResetToken indicates an expected call of RemoveResetToken +func (mr *MockSessionRunnerMockRecorder) RemoveResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResetToken", reflect.TypeOf((*MockSessionRunner)(nil).RemoveResetToken), arg0) +} + // Retire mocks base method func (m *MockSessionRunner) Retire(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/multiplexer.go b/multiplexer.go index e8c3b7db..eeffca53 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -1,6 +1,7 @@ package quic import ( + "bytes" "fmt" "net" "sync" @@ -14,13 +15,14 @@ var ( ) type multiplexer interface { - AddConn(net.PacketConn, int) (packetHandlerManager, error) + AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte) (packetHandlerManager, error) RemoveConn(net.PacketConn) error } type connManager struct { - connIDLen int - manager packetHandlerManager + connIDLen int + statelessResetKey []byte + manager packetHandlerManager } // The connMultiplexer listens on multiple net.PacketConns and dispatches @@ -29,7 +31,7 @@ type connMultiplexer struct { mutex sync.Mutex conns map[net.PacketConn]connManager - newPacketHandlerManager func(net.PacketConn, int, utils.Logger) packetHandlerManager // so it can be replaced in the tests + newPacketHandlerManager func(net.PacketConn, int, []byte, utils.Logger) packetHandlerManager // so it can be replaced in the tests logger utils.Logger } @@ -47,19 +49,30 @@ func getMultiplexer() multiplexer { return connMuxer } -func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandlerManager, error) { +func (m *connMultiplexer) AddConn( + c net.PacketConn, + connIDLen int, + statelessResetKey []byte, +) (packetHandlerManager, error) { m.mutex.Lock() defer m.mutex.Unlock() p, ok := m.conns[c] if !ok { - manager := m.newPacketHandlerManager(c, connIDLen, m.logger) - p = connManager{connIDLen: connIDLen, manager: manager} + manager := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, m.logger) + p = connManager{ + connIDLen: connIDLen, + statelessResetKey: statelessResetKey, + manager: manager, + } m.conns[c] = p } if p.connIDLen != connIDLen { return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) } + if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { + return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") + } return p.manager, nil } diff --git a/multiplexer_test.go b/multiplexer_test.go index f50f227f..1b40cf11 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -8,16 +8,23 @@ import ( var _ = Describe("Client Multiplexer", func() { It("adds a new packet conn ", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 8) + _, err := getMultiplexer().AddConn(conn, 8, nil) Expect(err).ToNot(HaveOccurred()) }) It("errors when adding an existing conn with a different connection ID length", func() { conn := newMockPacketConn() - _, err := getMultiplexer().AddConn(conn, 5) + _, err := getMultiplexer().AddConn(conn, 5, nil) Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 6) + _, err = getMultiplexer().AddConn(conn, 6, nil) Expect(err).To(MatchError("cannot use 6 byte connection IDs on a connection that is already using 5 byte connction IDs")) }) + It("errors when adding an existing conn with a different stateless rest key", func() { + conn := newMockPacketConn() + _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof")) + Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) + }) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 0ae65b51..f7913a44 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -1,7 +1,11 @@ package quic import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" "errors" + "hash" "net" "sync" "time" @@ -30,12 +34,19 @@ type packetHandlerMap struct { deleteRetiredSessionsAfter time.Duration + statelessResetHasher hash.Hash + logger utils.Logger } var _ packetHandlerManager = &packetHandlerMap{} -func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger) packetHandlerManager { +func newPacketHandlerMap( + conn net.PacketConn, + connIDLen int, + statelessResetKey []byte, + logger utils.Logger, +) packetHandlerManager { m := &packetHandlerMap{ conn: conn, connIDLen: connIDLen, @@ -43,6 +54,7 @@ func newPacketHandlerMap(conn net.PacketConn, connIDLen int, logger utils.Logger handlers: make(map[string]packetHandler), resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, + statelessResetHasher: hmac.New(sha256.New, statelessResetKey), logger: logger, } go m.listen() @@ -194,8 +206,7 @@ func (h *packetHandlerMap) handlePacket( return } if data[0]&0x80 == 0 { - // TODO(#943): send a stateless reset - h.logger.Debugf("received a short header packet with an unexpected connection ID %s", connID) + go h.maybeSendStatelessReset(p, connID) return } if h.server == nil { // no server set @@ -217,8 +228,30 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { var token [16]byte copy(token[:], data[len(data)-16:]) if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) + h.logger.Debugf("Received a stateless retry with token %#x. Closing session.", token) + go sess.destroy(errors.New("received a stateless reset")) return true } return false } + +func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { + h.statelessResetHasher.Write(connID.Bytes()) + var token [16]byte + copy(token[:], h.statelessResetHasher.Sum(nil)) + h.statelessResetHasher.Reset() + return token +} + +func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { + defer p.buffer.Release() + token := h.GetStatelessResetToken(connID) + h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) + data := make([]byte, 23) + rand.Read(data) + data[0] = (data[0] & 0x7f) | 0x40 + data = append(data, token[:]...) + if _, err := h.conn.WriteTo(data, p.remoteAddr); err != nil { + h.logger.Debugf("Error sending Stateless Reset: %s", err) + } +} diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 1b078c1b..86de4df2 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "errors" + "net" "time" "github.com/golang/mock/gomock" @@ -40,7 +41,7 @@ var _ = Describe("Packet Handler Map", func() { BeforeEach(func() { conn = newMockPacketConn() - handler = newPacketHandlerMap(conn, 5, utils.DefaultLogger).(*packetHandlerMap) + handler = newPacketHandlerMap(conn, 5, nil, utils.DefaultLogger).(*packetHandlerMap) }) AfterEach(func() { @@ -163,6 +164,14 @@ var _ = Describe("Packet Handler Map", func() { }) Context("stateless reset handling", func() { + It("generates stateless reset tokens", func() { + connID1 := []byte{0xde, 0xad, 0xbe, 0xef} + connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + token1 := handler.GetStatelessResetToken(connID1) + Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) + Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) + }) + It("handles stateless resets", func() { packetHandler := NewMockPacketHandler(mockCtrl) token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} @@ -195,16 +204,30 @@ var _ = Describe("Packet Handler Map", func() { It("deletes reset tokens", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + packetHandler := NewMockPacketHandler(mockCtrl) + handler.Add(connID, packetHandler) token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) handler.RemoveResetToken(token) - handler.handlePacket(nil, nil, getPacket(connID)) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - packet := append([]byte{0x40, 0xde, 0xca, 0xfb, 0xad, 0x99} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - handler.handlePacket(nil, nil, packet) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - Expect(handler.resetTokens).To(BeEmpty()) + packetHandler.EXPECT().handlePacket(gomock.Any()) + p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) + p = append(p, make([]byte, 50)...) + p = append(p, token[:]...) + handler.handlePacket(nil, nil, p) + // destroy() would be called from a separate go routine + // make sure we give it enough time to be called to cause an error here + time.Sleep(scaleDuration(25 * time.Millisecond)) + }) + + It("sends stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(addr, getPacketBuffer(), p) + var reset mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&reset)) + Expect(reset.to).To(Equal(addr)) + Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet + Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) }) }) diff --git a/server.go b/server.go index 4f7b418b..fe2f9388 100644 --- a/server.go +++ b/server.go @@ -36,6 +36,9 @@ type packetHandlerManager interface { Add(protocol.ConnectionID, packetHandler) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) + AddResetToken([16]byte, packetHandler) + RemoveResetToken([16]byte) + GetStatelessResetToken(protocol.ConnectionID) [16]byte SetServer(unknownPacketHandler) CloseServer() } @@ -55,6 +58,8 @@ type sessionRunner interface { OnHandshakeComplete(Session) Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) + AddResetToken([16]byte, packetHandler) + RemoveResetToken([16]byte) } type runner struct { @@ -143,7 +148,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, } } - sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength) + sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength, config.StatelessResetKey) if err != nil { return nil, err } @@ -266,6 +271,7 @@ func populateServerConfig(config *Config) *Config { MaxIncomingStreams: maxIncomingStreams, MaxIncomingUniStreams: maxIncomingUniStreams, ConnectionIDLength: connIDLen, + StatelessResetKey: config.StatelessResetKey, } } @@ -341,8 +347,8 @@ func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet pass s.logger.Debugf("Error parsing packet: %s", err) return false } + // Short header packets should never end up here in the first place if !hdr.IsLongHeader { - // TODO: send a stateless reset return false } // send a Version Negotiation Packet if the client is speaking a different protocol version @@ -430,8 +436,7 @@ func (s *server) createNewSession( srcConnID protocol.ConnectionID, version protocol.VersionNumber, ) (quicSession, error) { - // TODO(#855): generate a real token - token := [16]byte{42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42} + token := s.sessionHandler.GetStatelessResetToken(srcConnID) params := &handshake.TransportParameters{ InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData, InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData, diff --git a/server_test.go b/server_test.go index 7313f32b..5fb9d291 100644 --- a/server_test.go +++ b/server_test.go @@ -79,11 +79,12 @@ var _ = Describe("Server", func() { supportedVersions := []protocol.VersionNumber{protocol.VersionTLS} acceptCookie := func(_ net.Addr, _ *Cookie) bool { return true } config := Config{ - Versions: supportedVersions, - AcceptCookie: acceptCookie, - HandshakeTimeout: 1337 * time.Hour, - IdleTimeout: 42 * time.Minute, - KeepAlive: true, + Versions: supportedVersions, + AcceptCookie: acceptCookie, + HandshakeTimeout: 1337 * time.Hour, + IdleTimeout: 42 * time.Minute, + KeepAlive: true, + StatelessResetKey: []byte("foobar"), } ln, err := Listen(conn, tlsConf, &config) Expect(err).ToNot(HaveOccurred()) @@ -94,6 +95,7 @@ var _ = Describe("Server", func() { Expect(server.config.IdleTimeout).To(Equal(42 * time.Minute)) Expect(reflect.ValueOf(server.config.AcceptCookie)).To(Equal(reflect.ValueOf(acceptCookie))) Expect(server.config.KeepAlive).To(BeTrue()) + Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) // stop the listener Expect(ln.Close()).To(Succeed()) }) diff --git a/session.go b/session.go index bd8ba567..74a0f5f5 100644 --- a/session.go +++ b/session.go @@ -927,6 +927,9 @@ func (s *session) processTransportParameters(data []byte) { s.packer.HandleTransportParameters(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) s.connFlowController.UpdateSendWindow(params.InitialMaxData) + if params.StatelessResetToken != nil { + s.sessionRunner.AddResetToken(*params.StatelessResetToken, s) + } } func (s *session) processTransportParametersForClient(data []byte) (*handshake.TransportParameters, error) { From 19cf46d73686d62bf8383b42b9d6fe1e5b057795 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 6 Mar 2019 14:40:00 +0900 Subject: [PATCH 7/9] add an integration test for stateless resets --- integrationtests/self/stateless_reset_test.go | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 integrationtests/self/stateless_reset_test.go diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go new file mode 100644 index 00000000..2a565830 --- /dev/null +++ b/integrationtests/self/stateless_reset_test.go @@ -0,0 +1,106 @@ +package self + +import ( + "crypto/tls" + "fmt" + "math/rand" + "net" + "time" + + quic "github.com/lucas-clemente/quic-go" + quicproxy "github.com/lucas-clemente/quic-go/integrationtests/tools/proxy" + "github.com/lucas-clemente/quic-go/internal/testdata" + "github.com/lucas-clemente/quic-go/internal/utils" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Stateless Resets", func() { + connIDLens := []int{0, 10} + + for i := range connIDLens { + connIDLen := connIDLens[i] + + It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() { + statelessResetKey := make([]byte, 32) + rand.Read(statelessResetKey) + serverConfig := &quic.Config{StatelessResetKey: statelessResetKey} + + ln, err := quic.ListenAddr("localhost:0", testdata.GetTLSConfig(), serverConfig) + Expect(err).ToNot(HaveOccurred()) + serverPort := ln.Addr().(*net.UDPAddr).Port + + closeServer := make(chan struct{}) + + go func() { + defer GinkgoRecover() + sess, err := ln.Accept() + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + <-closeServer + ln.Close() + }() + + drop := utils.AtomicBool{} + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DropPacket: func(d quicproxy.Direction, p uint64) bool { + return drop.Get() + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + sess, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + &tls.Config{RootCAs: testdata.GetRootCA()}, + &quic.Config{ + ConnectionIDLength: connIDLen, + IdleTimeout: 2 * time.Second, + }, + ) + Expect(err).ToNot(HaveOccurred()) + str, err := sess.AcceptStream() + Expect(err).ToNot(HaveOccurred()) + data := make([]byte, 6) + _, err = str.Read(data) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + + // make sure that the CONNECTION_CLOSE is dropped + drop.Set(true) + close(closeServer) + time.Sleep(100 * time.Millisecond) + + ln2, err := quic.ListenAddr( + fmt.Sprintf("localhost:%d", serverPort), + testdata.GetTLSConfig(), + serverConfig, + ) + Expect(err).ToNot(HaveOccurred()) + drop.Set(false) + + acceptStopped := make(chan struct{}) + go func() { + defer GinkgoRecover() + _, err := ln2.Accept() + Expect(err).To(HaveOccurred()) + close(acceptStopped) + }() + + // trigger something (not too small) to be sent, + // so that we receive the stateless reset + _, err = str.Write([]byte("Lorem ipsum dolor sit amet.")) + Expect(err).ToNot(HaveOccurred()) + _, err = str.Read([]byte{0}) + Expect(err).To(MatchError("INTERNAL_ERROR: received a stateless reset")) + + Expect(ln2.Close()).To(Succeed()) + Eventually(acceptStopped).Should(BeClosed()) + }) + } +}) From 5c20519743c6771d92390c4a0b662257f0b6bde6 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 6 Mar 2019 14:42:38 +0900 Subject: [PATCH 8/9] don't send stateless resets for small packets This prevents amplification and looping. --- packet_handler_map.go | 5 +++++ packet_handler_map_test.go | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/packet_handler_map.go b/packet_handler_map.go index f7913a44..43c3f24b 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -245,6 +245,11 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { defer p.buffer.Release() + // Don't send a stateless reset in response to very small packets. + // This includes packets that could be stateless resets. + if len(p.data) <= protocol.MinStatelessResetSize { + return + } token := h.GetStatelessResetToken(connID) h.logger.Debugf("Sending stateless reset to %s (connection ID: %s). Token: %#x", p.remoteAddr, connID, token) data := make([]byte, 23) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 86de4df2..051f518e 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -229,6 +229,13 @@ var _ = Describe("Packet Handler Map", func() { Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) }) + + It("doesn't send stateless resets for small packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) + handler.handlePacket(addr, getPacketBuffer(), p) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) }) Context("running a server", func() { From b3fe0fdbf9022e802d6b223919124106c365f506 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 7 Mar 2019 09:25:46 +0900 Subject: [PATCH 9/9] only send stateless resets if a stateless reset key is configured --- interface.go | 1 + packet_handler_map.go | 16 +++- packet_handler_map_test.go | 188 ++++++++++++++++++++++--------------- 3 files changed, 127 insertions(+), 78 deletions(-) diff --git a/interface.go b/interface.go index 4827c7d9..6402d076 100644 --- a/interface.go +++ b/interface.go @@ -209,6 +209,7 @@ type Config struct { // If set to a negative value, it doesn't allow any unidirectional streams. MaxIncomingUniStreams int // The StatelessResetKey is used to generate stateless reset tokens. + // If no key is configured, sending of stateless resets is disabled. StatelessResetKey []byte // KeepAlive defines whether this peer will periodically send a packet to keep the connection alive. KeepAlive bool diff --git a/packet_handler_map.go b/packet_handler_map.go index 43c3f24b..0af6a407 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -34,7 +34,8 @@ type packetHandlerMap struct { deleteRetiredSessionsAfter time.Duration - statelessResetHasher hash.Hash + statelessResetEnabled bool + statelessResetHasher hash.Hash logger utils.Logger } @@ -54,6 +55,7 @@ func newPacketHandlerMap( handlers: make(map[string]packetHandler), resetTokens: make(map[[16]byte]packetHandler), deleteRetiredSessionsAfter: protocol.RetiredConnectionIDDeleteTimeout, + statelessResetEnabled: len(statelessResetKey) > 0, statelessResetHasher: hmac.New(sha256.New, statelessResetKey), logger: logger, } @@ -236,8 +238,15 @@ func (h *packetHandlerMap) maybeHandleStatelessReset(data []byte) bool { } func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) [16]byte { - h.statelessResetHasher.Write(connID.Bytes()) var token [16]byte + if !h.statelessResetEnabled { + // Return a random stateless reset token. + // This token will be sent in the server's transport parameters. + // By using a random token, an off-path attacker won't be able to disrupt the connection. + rand.Read(token[:]) + return token + } + h.statelessResetHasher.Write(connID.Bytes()) copy(token[:], h.statelessResetHasher.Sum(nil)) h.statelessResetHasher.Reset() return token @@ -245,6 +254,9 @@ func (h *packetHandlerMap) GetStatelessResetToken(connID protocol.ConnectionID) func (h *packetHandlerMap) maybeSendStatelessReset(p *receivedPacket, connID protocol.ConnectionID) { defer p.buffer.Release() + if !h.statelessResetEnabled { + return + } // Don't send a stateless reset in response to very small packets. // This includes packets that could be stateless resets. if len(p.data) <= protocol.MinStatelessResetSize { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 051f518e..e52c1305 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "crypto/rand" "errors" "net" "time" @@ -18,6 +19,9 @@ var _ = Describe("Packet Handler Map", func() { var ( handler *packetHandlerMap conn *mockPacketConn + + connIDLen int + statelessResetKey []byte ) getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) []byte { @@ -40,8 +44,13 @@ var _ = Describe("Packet Handler Map", func() { } BeforeEach(func() { + statelessResetKey = nil + connIDLen = 0 + }) + + JustBeforeEach(func() { conn = newMockPacketConn() - handler = newPacketHandlerMap(conn, 5, nil, utils.DefaultLogger).(*packetHandlerMap) + handler = newPacketHandlerMap(conn, connIDLen, statelessResetKey, utils.DefaultLogger).(*packetHandlerMap) }) AfterEach(func() { @@ -80,6 +89,10 @@ var _ = Describe("Packet Handler Map", func() { }) Context("handling packets", func() { + BeforeEach(func() { + connIDLen = 5 + }) + It("handles packets for different packet handlers on the same packet conn", func() { connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} @@ -163,81 +176,6 @@ var _ = Describe("Packet Handler Map", func() { }) }) - Context("stateless reset handling", func() { - It("generates stateless reset tokens", func() { - connID1 := []byte{0xde, 0xad, 0xbe, 0xef} - connID2 := []byte{0xde, 0xca, 0xfb, 0xad} - token1 := handler.GetStatelessResetToken(connID1) - Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) - Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) - }) - - It("handles stateless resets", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { - close(destroyed) - }) - conn.dataToRead <- packet - Eventually(destroyed).Should(BeClosed()) - }) - - It("handles stateless resets for 0-length connection IDs", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { - close(destroyed) - }) - conn.dataToRead <- packet - Eventually(destroyed).Should(BeClosed()) - }) - - It("deletes reset tokens", func() { - handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(connID, packetHandler) - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) - handler.RemoveResetToken(token) - packetHandler.EXPECT().handlePacket(gomock.Any()) - p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) - p = append(p, make([]byte, 50)...) - p = append(p, token[:]...) - handler.handlePacket(nil, nil, p) - // destroy() would be called from a separate go routine - // make sure we give it enough time to be called to cause an error here - time.Sleep(scaleDuration(25 * time.Millisecond)) - }) - - It("sends stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(addr, getPacketBuffer(), p) - var reset mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&reset)) - Expect(reset.to).To(Equal(addr)) - Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet - Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) - }) - - It("doesn't send stateless resets for small packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) - handler.handlePacket(addr, getPacketBuffer(), p) - Consistently(conn.dataWritten).ShouldNot(Receive()) - }) - }) - Context("running a server", func() { It("adds a server", func() { connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} @@ -274,4 +212,102 @@ var _ = Describe("Packet Handler Map", func() { handler.handlePacket(nil, nil, p) }) }) + + Context("stateless resets", func() { + BeforeEach(func() { + connIDLen = 5 + }) + + Context("handling", func() { + It("handles stateless resets", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- packet + Eventually(destroyed).Should(BeClosed()) + }) + + It("handles stateless resets for 0-length connection IDs", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + destroyed := make(chan struct{}) + packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { + close(destroyed) + }) + conn.dataToRead <- packet + Eventually(destroyed).Should(BeClosed()) + }) + + It("deletes reset tokens", func() { + handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + packetHandler := NewMockPacketHandler(mockCtrl) + handler.Add(connID, packetHandler) + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) + handler.RemoveResetToken(token) + packetHandler.EXPECT().handlePacket(gomock.Any()) + p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) + p = append(p, make([]byte, 50)...) + p = append(p, token[:]...) + handler.handlePacket(nil, nil, p) + // destroy() would be called from a separate go routine + // make sure we give it enough time to be called to cause an error here + time.Sleep(scaleDuration(25 * time.Millisecond)) + }) + }) + + Context("generating", func() { + BeforeEach(func() { + key := make([]byte, 32) + rand.Read(key) + statelessResetKey = key + }) + + It("generates stateless reset tokens", func() { + connID1 := []byte{0xde, 0xad, 0xbe, 0xef} + connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + token1 := handler.GetStatelessResetToken(connID1) + Expect(handler.GetStatelessResetToken(connID1)).To(Equal(token1)) + Expect(handler.GetStatelessResetToken(connID2)).ToNot(Equal(token1)) + }) + + It("sends stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(addr, getPacketBuffer(), p) + var reset mockPacketConnWrite + Eventually(conn.dataWritten).Should(Receive(&reset)) + Expect(reset.to).To(Equal(addr)) + Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet + Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) + }) + + It("doesn't send stateless resets for small packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) + handler.handlePacket(addr, getPacketBuffer(), p) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + }) + + Context("if no key is configured", func() { + It("doesn't send stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(addr, getPacketBuffer(), p) + Consistently(conn.dataWritten).ShouldNot(Receive()) + }) + }) + }) })