From 5203b026e30729aeab8a10ced0436b2f93a01baa Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 26 Oct 2019 09:42:21 +0700 Subject: [PATCH] use the connection ID manager to save the server's stateless reset token --- conn_id_manager.go | 27 +++++++++++++++++++++------ conn_id_manager_test.go | 19 ++++++++++++++++--- session.go | 15 +++++++++++---- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/conn_id_manager.go b/conn_id_manager.go index 7c2455c97..0b0f3759e 100644 --- a/conn_id_manager.go +++ b/conn_id_manager.go @@ -11,19 +11,24 @@ import ( type connIDManager struct { queue utils.NewConnectionIDList - activeSequenceNumber uint64 - activeConnectionID protocol.ConnectionID + activeSequenceNumber uint64 + activeConnectionID protocol.ConnectionID + activeStatelessResetToken *[16]byte - queueControlFrame func(wire.Frame) + addStatelessResetToken func([16]byte) + queueControlFrame func(wire.Frame) } func newConnIDManager( initialDestConnID protocol.ConnectionID, + addStatelessResetToken func([16]byte), queueControlFrame func(wire.Frame), ) *connIDManager { - h := &connIDManager{queueControlFrame: queueControlFrame} - h.activeConnectionID = initialDestConnID - return h + return &connIDManager{ + activeConnectionID: initialDestConnID, + addStatelessResetToken: addStatelessResetToken, + queueControlFrame: queueControlFrame, + } } func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error { @@ -96,6 +101,7 @@ func (h *connIDManager) updateConnectionID() { front := h.queue.Remove(h.queue.Front()) h.activeSequenceNumber = front.SequenceNumber h.activeConnectionID = front.ConnectionID + h.activeStatelessResetToken = front.StatelessResetToken } // is called when the server performs a Retry @@ -107,6 +113,15 @@ func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) { h.activeConnectionID = newConnID } +// is called when the server provides a stateless reset token in the transport parameters +func (h *connIDManager) SetStatelessResetToken(token [16]byte) { + if h.activeSequenceNumber != 0 { + panic("expected first connection ID to have sequence number 0") + } + h.activeStatelessResetToken = &token + h.addStatelessResetToken(token) +} + func (h *connIDManager) Get() protocol.ConnectionID { return h.activeConnectionID } diff --git a/conn_id_manager_test.go b/conn_id_manager_test.go index fa9fd2124..9f7ea0914 100644 --- a/conn_id_manager_test.go +++ b/conn_id_manager_test.go @@ -11,14 +11,20 @@ var _ = Describe("Connection ID Manager", func() { var ( m *connIDManager frameQueue []wire.Frame + tokenAdded *[16]byte ) initialConnID := protocol.ConnectionID{1, 1, 1, 1} BeforeEach(func() { frameQueue = nil - m = newConnIDManager(initialConnID, func(f wire.Frame) { - frameQueue = append(frameQueue, f) - }) + tokenAdded = nil + m = newConnIDManager( + initialConnID, + func(token [16]byte) { tokenAdded = &token }, + func(f wire.Frame, + ) { + frameQueue = append(frameQueue, f) + }) }) get := func() (protocol.ConnectionID, *[16]byte) { @@ -38,6 +44,13 @@ var _ = Describe("Connection ID Manager", func() { Expect(m.Get()).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5})) }) + It("sets the token for the first connection ID", func() { + token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + m.SetStatelessResetToken(token) + Expect(*m.activeStatelessResetToken).To(Equal(token)) + Expect(*tokenAdded).To(Equal(token)) + }) + It("adds and gets connection IDs", func() { Expect(m.Add(&wire.NewConnectionIDFrame{ SequenceNumber: 10, diff --git a/session.go b/session.go index f101262bb..39b292829 100644 --- a/session.go +++ b/session.go @@ -206,7 +206,11 @@ var newSession = func( logger: logger, version: v, } - s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame) + s.connIDManager = newConnIDManager( + destConnID, + func(token [16]byte) { runner.AddResetToken(token, s) }, + s.queueControlFrame, + ) s.preSetup() s.sentPacketHandler = ackhandler.NewSentPacketHandler(0, s.rttStats, s.traceCallback, s.logger) initialStream := newCryptoStream() @@ -230,7 +234,6 @@ var newSession = func( logger, ) s.cryptoStreamHandler = cs - s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame) s.packer = newPacketPacker( s.srcConnID, s.connIDManager.Get, @@ -275,7 +278,11 @@ var newClientSession = func( initialVersion: initialVersion, version: v, } - s.connIDManager = newConnIDManager(destConnID, s.queueControlFrame) + s.connIDManager = newConnIDManager( + destConnID, + func(token [16]byte) { runner.AddResetToken(token, s) }, + s.queueControlFrame, + ) s.preSetup() s.sentPacketHandler = ackhandler.NewSentPacketHandler(initialPacketNumber, s.rttStats, s.traceCallback, s.logger) initialStream := newCryptoStream() @@ -1023,7 +1030,7 @@ func (s *session) processTransportParameters(data []byte) { s.connFlowController.UpdateSendWindow(params.InitialMaxData) s.rttStats.SetMaxAckDelay(params.MaxAckDelay) if params.StatelessResetToken != nil { - s.sessionRunner.AddResetToken(*params.StatelessResetToken, s) + s.connIDManager.SetStatelessResetToken(*params.StatelessResetToken) } // On the server side, the early session is ready as soon as we processed // the client's transport parameters.