diff --git a/conn_id_generator.go b/conn_id_generator.go index 07168955..33f977b8 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -16,30 +16,33 @@ type connIDGenerator struct { activeSrcConnIDs map[uint64]protocol.ConnectionID initialClientDestConnID protocol.ConnectionID - addConnectionID func(protocol.ConnectionID) [16]byte - removeConnectionID func(protocol.ConnectionID) - retireConnectionID func(protocol.ConnectionID) - replaceWithClosed func(protocol.ConnectionID, packetHandler) - queueControlFrame func(wire.Frame) + addConnectionID func(protocol.ConnectionID) + getStatelessResetToken func(protocol.ConnectionID) [16]byte + removeConnectionID func(protocol.ConnectionID) + retireConnectionID func(protocol.ConnectionID) + replaceWithClosed func(protocol.ConnectionID, packetHandler) + queueControlFrame func(wire.Frame) } func newConnIDGenerator( initialConnectionID protocol.ConnectionID, initialClientDestConnID protocol.ConnectionID, // nil for the client - addConnectionID func(protocol.ConnectionID) [16]byte, + addConnectionID func(protocol.ConnectionID), + getStatelessResetToken func(protocol.ConnectionID) [16]byte, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), replaceWithClosed func(protocol.ConnectionID, packetHandler), queueControlFrame func(wire.Frame), ) *connIDGenerator { m := &connIDGenerator{ - connIDLen: initialConnectionID.Len(), - activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), - addConnectionID: addConnectionID, - removeConnectionID: removeConnectionID, - retireConnectionID: retireConnectionID, - replaceWithClosed: replaceWithClosed, - queueControlFrame: queueControlFrame, + connIDLen: initialConnectionID.Len(), + activeSrcConnIDs: make(map[uint64]protocol.ConnectionID), + addConnectionID: addConnectionID, + getStatelessResetToken: getStatelessResetToken, + removeConnectionID: removeConnectionID, + retireConnectionID: retireConnectionID, + replaceWithClosed: replaceWithClosed, + queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID m.initialClientDestConnID = initialClientDestConnID @@ -88,11 +91,11 @@ func (m *connIDGenerator) issueNewConnID() error { return err } m.activeSrcConnIDs[m.highestSeq+1] = connID - token := m.addConnectionID(connID) + m.addConnectionID(connID) m.queueControlFrame(&wire.NewConnectionIDFrame{ SequenceNumber: m.highestSeq + 1, ConnectionID: connID, - StatelessResetToken: token, + StatelessResetToken: m.getStatelessResetToken(connID), }) m.highestSeq++ return nil diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index 6ba7430c..1eeda3f6 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -20,6 +20,10 @@ var _ = Describe("Connection ID Generator", func() { initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} initialClientDestConnID := protocol.ConnectionID{0xa, 0xb, 0xc, 0xd, 0xe} + connIDToToken := func(c protocol.ConnectionID) [16]byte { + return [16]byte{c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0], c[0]} + } + BeforeEach(func() { addedConnIDs = nil retiredConnIDs = nil @@ -29,11 +33,8 @@ var _ = Describe("Connection ID Generator", func() { g = newConnIDGenerator( initialConnID, initialClientDestConnID, - func(c protocol.ConnectionID) [16]byte { - addedConnIDs = append(addedConnIDs, c) - l := uint8(len(addedConnIDs)) - return [16]byte{l, l, l, l, l, l, l, l, l, l, l, l, l, l, l, l} - }, + func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) }, + connIDToToken, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h }, @@ -55,8 +56,7 @@ var _ = Describe("Connection ID Generator", func() { nf := f.(*wire.NewConnectionIDFrame) Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1)) Expect(nf.ConnectionID.Len()).To(Equal(7)) - j := uint8(i + 1) - Expect(nf.StatelessResetToken).To(Equal([16]byte{j, j, j, j, j, j, j, j, j, j, j, j, j, j, j, j})) + Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID))) } }) diff --git a/mock_packet_handler_manager_test.go b/mock_packet_handler_manager_test.go index 4df0d310..13faa199 100644 --- a/mock_packet_handler_manager_test.go +++ b/mock_packet_handler_manager_test.go @@ -35,11 +35,9 @@ func (m *MockPacketHandlerManager) EXPECT() *MockPacketHandlerManagerMockRecorde } // Add mocks base method -func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte { +func (m *MockPacketHandlerManager) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0, arg1) - ret0, _ := ret[0].([16]byte) - return ret0 + m.ctrl.Call(m, "Add", arg0, arg1) } // Add indicates an expected call of Add diff --git a/mock_session_runner_test.go b/mock_session_runner_test.go index 709d5701..8eaeb750 100644 --- a/mock_session_runner_test.go +++ b/mock_session_runner_test.go @@ -35,11 +35,9 @@ func (m *MockSessionRunner) EXPECT() *MockSessionRunnerMockRecorder { } // Add mocks base method -func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) [16]byte { +func (m *MockSessionRunner) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Add", arg0, arg1) - ret0, _ := ret[0].([16]byte) - return ret0 + m.ctrl.Call(m, "Add", arg0, arg1) } // Add indicates an expected call of Add @@ -60,6 +58,20 @@ func (mr *MockSessionRunnerMockRecorder) AddResetToken(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddResetToken", reflect.TypeOf((*MockSessionRunner)(nil).AddResetToken), arg0, arg1) } +// GetStatelessResetToken mocks base method +func (m *MockSessionRunner) GetStatelessResetToken(arg0 protocol.ConnectionID) [16]byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetStatelessResetToken", arg0) + ret0, _ := ret[0].([16]byte) + return ret0 +} + +// GetStatelessResetToken indicates an expected call of GetStatelessResetToken +func (mr *MockSessionRunnerMockRecorder) GetStatelessResetToken(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatelessResetToken", reflect.TypeOf((*MockSessionRunner)(nil).GetStatelessResetToken), arg0) +} + // Remove mocks base method func (m *MockSessionRunner) Remove(arg0 protocol.ConnectionID) { m.ctrl.T.Helper() diff --git a/packet_handler_map.go b/packet_handler_map.go index d05fe880..179d6583 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -95,11 +95,10 @@ func (h *packetHandlerMap) logUsage() { } } -func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) [16]byte { +func (h *packetHandlerMap) Add(id protocol.ConnectionID, handler packetHandler) { h.mutex.Lock() h.handlers[string(id)] = handler h.mutex.Unlock() - return h.GetStatelessResetToken(id) } func (h *packetHandlerMap) AddIfNotTaken(id protocol.ConnectionID, handler packetHandler) bool /* was added */ { diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 120bc5bd..e45147b6 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -295,14 +295,6 @@ var _ = Describe("Packet Handler Map", func() { statelessResetKey = key }) - It("generates stateless reset tokens when adding new sessions", func() { - connID1 := []byte{0xde, 0xad, 0xbe, 0xef} - connID2 := []byte{0xde, 0xca, 0xfb, 0xad} - token1 := handler.Add(connID1, nil) - Expect(handler.Add(connID1, nil)).To(Equal(token1)) - Expect(handler.Add(connID2, nil)).ToNot(Equal(token1)) - }) - It("generates stateless reset tokens", func() { connID1 := []byte{0xde, 0xad, 0xbe, 0xef} connID2 := []byte{0xde, 0xca, 0xfb, 0xad} diff --git a/server.go b/server.go index 8f1f381f..333f9d45 100644 --- a/server.go +++ b/server.go @@ -37,7 +37,6 @@ type packetHandlerManager interface { CloseServer() sessionRunner AddIfNotTaken(protocol.ConnectionID, packetHandler) bool - GetStatelessResetToken(protocol.ConnectionID) [16]byte } type quicSession interface { diff --git a/session.go b/session.go index 085a86df..e03aed8b 100644 --- a/session.go +++ b/session.go @@ -74,7 +74,8 @@ func (p *receivedPacket) Clone() *receivedPacket { } type sessionRunner interface { - Add(protocol.ConnectionID, packetHandler) [16]byte + Add(protocol.ConnectionID, packetHandler) + GetStatelessResetToken(protocol.ConnectionID) [16]byte Retire(protocol.ConnectionID) Remove(protocol.ConnectionID) ReplaceWithClosed(protocol.ConnectionID, packetHandler) @@ -232,7 +233,8 @@ var newSession = func( s.connIDGenerator = newConnIDGenerator( srcConnID, clientDestConnID, - func(connID protocol.ConnectionID) [16]byte { return runner.Add(connID, s) }, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, runner.Remove, runner.Retire, runner.ReplaceWithClosed, @@ -335,7 +337,8 @@ var newClientSession = func( s.connIDGenerator = newConnIDGenerator( srcConnID, nil, - func(connID protocol.ConnectionID) [16]byte { return runner.Add(connID, s) }, + func(connID protocol.ConnectionID) { runner.Add(connID, s) }, + runner.GetStatelessResetToken, runner.Remove, runner.Retire, runner.ReplaceWithClosed, diff --git a/session_test.go b/session_test.go index 023bf7a3..5d5a2193 100644 --- a/session_test.go +++ b/session_test.go @@ -1318,6 +1318,7 @@ var _ = Describe("Session", func() { packer.EXPECT().HandleTransportParameters(params) packer.EXPECT().PackPacket().MaxTimes(3) Expect(sess.earlySessionReady()).ToNot(BeClosed()) + sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2) sess.processTransportParameters(params) Expect(sess.earlySessionReady()).To(BeClosed())