diff --git a/closed_session_test.go b/closed_session_test.go index 9cabfeecd..2b55f12a7 100644 --- a/closed_session_test.go +++ b/closed_session_test.go @@ -4,6 +4,7 @@ import ( "errors" "time" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -14,11 +15,11 @@ import ( var _ = Describe("Closed local session", func() { var ( sess packetHandler - mconn *mockConnection + mconn *MockConnection ) BeforeEach(func() { - mconn = newMockConnection() + mconn = NewMockConnection(mockCtrl) sess = newClosedLocalSession(mconn, []byte("close"), protocol.PerspectiveClient, utils.DefaultLogger) }) @@ -33,12 +34,14 @@ var _ = Describe("Closed local session", func() { }) It("repeats the packet containing the CONNECTION_CLOSE frame", func() { + written := make(chan []byte) + mconn.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).AnyTimes() for i := 1; i <= 20; i++ { sess.handlePacket(&receivedPacket{}) if i == 1 || i == 2 || i == 4 || i == 8 || i == 16 { - Eventually(mconn.written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE + Eventually(written).Should(Receive(Equal([]byte("close")))) // receive the CONNECTION_CLOSE } else { - Consistently(mconn.written, 10*time.Millisecond).Should(HaveLen(0)) + Consistently(written, 10*time.Millisecond).Should(HaveLen(0)) } } // stop the session