diff --git a/conn_id_generator.go b/conn_id_generator.go index ab748f1af..a95110af8 100644 --- a/conn_id_generator.go +++ b/conn_id_generator.go @@ -18,6 +18,7 @@ type connIDGenerator struct { addConnectionID func(protocol.ConnectionID) [16]byte removeConnectionID func(protocol.ConnectionID) retireConnectionID func(protocol.ConnectionID) + replaceWithClosed func(protocol.ConnectionID, packetHandler) queueControlFrame func(wire.Frame) } @@ -26,6 +27,7 @@ func newConnIDGenerator( addConnectionID func(protocol.ConnectionID) [16]byte, removeConnectionID func(protocol.ConnectionID), retireConnectionID func(protocol.ConnectionID), + replaceWithClosed func(protocol.ConnectionID, packetHandler), queueControlFrame func(wire.Frame), ) *connIDGenerator { m := &connIDGenerator{ @@ -34,6 +36,7 @@ func newConnIDGenerator( addConnectionID: addConnectionID, removeConnectionID: removeConnectionID, retireConnectionID: retireConnectionID, + replaceWithClosed: replaceWithClosed, queueControlFrame: queueControlFrame, } m.activeSrcConnIDs[0] = initialConnectionID @@ -93,3 +96,9 @@ func (m *connIDGenerator) RemoveAll() { m.removeConnectionID(connID) } } + +func (m *connIDGenerator) ReplaceWithClosed(handler packetHandler) { + for _, connID := range m.activeSrcConnIDs { + m.replaceWithClosed(connID, handler) + } +} diff --git a/conn_id_generator_test.go b/conn_id_generator_test.go index c775d694f..81b363063 100644 --- a/conn_id_generator_test.go +++ b/conn_id_generator_test.go @@ -10,20 +10,23 @@ import ( var _ = Describe("Connection ID Generator", func() { var ( - addedConnIDs []protocol.ConnectionID - retiredConnIDs []protocol.ConnectionID - removedConnIDs []protocol.ConnectionID - queuedFrames []wire.Frame - g *connIDGenerator + addedConnIDs []protocol.ConnectionID + retiredConnIDs []protocol.ConnectionID + removedConnIDs []protocol.ConnectionID + replacedWithClosed map[string]packetHandler + queuedFrames []wire.Frame + g *connIDGenerator ) + initialConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7} BeforeEach(func() { addedConnIDs = nil retiredConnIDs = nil removedConnIDs = nil queuedFrames = nil + replacedWithClosed = make(map[string]packetHandler) g = newConnIDGenerator( - protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7}, + initialConnID, func(c protocol.ConnectionID) [16]byte { addedConnIDs = append(addedConnIDs, c) l := uint8(len(addedConnIDs)) @@ -31,6 +34,7 @@ var _ = Describe("Connection ID Generator", func() { }, func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) }, func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) }, + func(c protocol.ConnectionID, h packetHandler) { replacedWithClosed[string(c)] = h }, func(f wire.Frame) { queuedFrames = append(queuedFrames, f) }, ) }) @@ -81,7 +85,7 @@ var _ = Describe("Connection ID Generator", func() { Expect(g.Retire(0)).To(Succeed()) Expect(removedConnIDs).To(BeEmpty()) Expect(retiredConnIDs).To(HaveLen(1)) - Expect(retiredConnIDs[0]).To(Equal(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7})) + Expect(retiredConnIDs[0]).To(Equal(initialConnID)) Expect(addedConnIDs).To(BeEmpty()) }) @@ -101,11 +105,24 @@ var _ = Describe("Connection ID Generator", func() { Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) Expect(queuedFrames).To(HaveLen(5)) g.RemoveAll() - Expect(removedConnIDs).To(HaveLen(6)) // initial connection ID and newly issued ones - Expect(removedConnIDs).To(ContainElement(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7})) // initial connection ID + Expect(removedConnIDs).To(HaveLen(6)) // initial connection ID and newly issued ones + Expect(removedConnIDs).To(ContainElement(initialConnID)) for _, f := range queuedFrames { nf := f.(*wire.NewConnectionIDFrame) Expect(removedConnIDs).To(ContainElement(nf.ConnectionID)) } }) + + It("replaces with a closed session for all connection IDs", func() { + Expect(g.SetMaxActiveConnIDs(5)).To(Succeed()) + Expect(queuedFrames).To(HaveLen(5)) + sess := NewMockPacketHandler(mockCtrl) + g.ReplaceWithClosed(sess) + Expect(replacedWithClosed).To(HaveLen(6)) // initial connection ID and newly issued ones + Expect(replacedWithClosed).To(HaveKeyWithValue(string(initialConnID), sess)) + for _, f := range queuedFrames { + nf := f.(*wire.NewConnectionIDFrame) + Expect(replacedWithClosed).To(HaveKeyWithValue(string(nf.ConnectionID), sess)) + } + }) })