diff --git a/connection_test.go b/connection_test.go index 51ea93e4..f79e5223 100644 --- a/connection_test.go +++ b/connection_test.go @@ -270,95 +270,87 @@ func TestConnectionServerInvalidFrames(t *testing.T) { } } -func TestConnectionTransportError(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr)) - errChan := make(chan error, 1) - expectedErr := &qerr.TransportError{ - ErrorCode: 1337, - FrameType: 42, - ErrorMessage: "foobar", - } - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - b := getPacketBuffer() - b.Data = append(b.Data, []byte("connection close")...) - tc.packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil) - tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any()) - tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(expectedErr), - tracer.EXPECT().Close(), - ) - - go func() { errChan <- tc.conn.run() }() - tc.conn.closeLocal(expectedErr) - - select { - case err := <-errChan: - require.ErrorIs(t, err, expectedErr) - case <-time.After(time.Second): - t.Fatal("timeout") - } - - // further calls to CloseWithError don't do anything - tc.conn.CloseWithError(42, "another error") +func TestConnectionClose(t *testing.T) { + t.Run("transport error", func(t *testing.T) { + expectedErr := &qerr.TransportError{ + ErrorCode: 1337, + FrameType: 42, + ErrorMessage: "foobar", + } + testConnectionClose(t, false, expectedErr) + }) + t.Run("application error", func(t *testing.T) { + expectedErr := &qerr.ApplicationError{ + ErrorCode: 1337, + ErrorMessage: "foobar", + } + testConnectionClose(t, true, expectedErr) + }) } -func TestConnectionApplicationClose(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr)) - errChan := make(chan error, 1) - expectedErr := &qerr.ApplicationError{ - ErrorCode: 1337, - ErrorMessage: "foobar", - } - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - b := getPacketBuffer() - b.Data = append(b.Data, []byte("connection close")...) - tc.packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil) - tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any()) - tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(expectedErr), - tracer.EXPECT().Close(), - ) +func testConnectionClose(t *testing.T, useApplicationClose bool, expectedErr error) { + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr)) + errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - tc.conn.CloseWithError(1337, "foobar") + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + b := getPacketBuffer() + b.Data = append(b.Data, []byte("connection close")...) + if useApplicationClose { + tc.packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil) + } else { + tc.packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil) + } + tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any()) + tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(expectedErr), + tracer.EXPECT().Close(), + ) - select { - case err := <-errChan: - require.ErrorIs(t, err, expectedErr) - case <-time.After(time.Second): - t.Fatal("timeout") - } + go func() { errChan <- tc.conn.run() }() + tc.conn.closeLocal(expectedErr) - // further calls to CloseWithError don't do anything - tc.conn.CloseWithError(42, "another error") + synctest.Wait() + + select { + case err := <-errChan: + require.ErrorIs(t, err, expectedErr) + default: + t.Fatal("connection was not closed") + } + + // further calls to CloseWithError don't do anything + tc.conn.CloseWithError(42, "another error") + }) } func TestConnectionStatelessReset(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr)) - errChan := make(chan error, 1) - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - gomock.InOrder( - tracer.EXPECT().ClosedConnection(&StatelessResetError{}), - tracer.EXPECT().Close(), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr)) + errChan := make(chan error, 1) + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + gomock.InOrder( + tracer.EXPECT().ClosedConnection(&StatelessResetError{}), + tracer.EXPECT().Close(), + ) - go func() { errChan <- tc.conn.run() }() - tc.conn.destroy(&StatelessResetError{}) + go func() { errChan <- tc.conn.run() }() + tc.conn.destroy(&StatelessResetError{}) - select { - case err := <-errChan: - require.ErrorIs(t, err, &StatelessResetError{}) - case <-time.After(time.Second): - t.Fatal("timeout") - } + synctest.Wait() + + select { + case err := <-errChan: + require.ErrorIs(t, err, &StatelessResetError{}) + default: + t.Fatal("connection was not closed") + } + }) } func getLongHeaderPacket(t *testing.T, remoteAddr net.Addr, extHdr *wire.ExtendedHeader, data []byte) receivedPacket { @@ -774,48 +766,52 @@ func TestConnectionMaxUnprocessedPackets(t *testing.T) { } func TestConnectionRemoteClose(t *testing.T) { - mockCtrl := gomock.NewController(t) - tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) - unpacker := NewMockUnpacker(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - false, - connectionOptTracer(tr), - connectionOptUnpacker(unpacker), - ) - ccf, err := (&wire.ConnectionCloseFrame{ - ErrorCode: uint64(qerr.StreamLimitError), - ReasonPhrase: "foobar", - }).Append(nil, protocol.Version1) - require.NoError(t, err) - unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), ccf, nil) - tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl) + unpacker := NewMockUnpacker(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + false, + connectionOptTracer(tr), + connectionOptUnpacker(unpacker), + ) + ccf, err := (&wire.ConnectionCloseFrame{ + ErrorCode: uint64(qerr.StreamLimitError), + ReasonPhrase: "foobar", + }).Append(nil, protocol.Version1) + require.NoError(t, err) + unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), ccf, nil) + tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true} - tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) - tracerErrChan := make(chan error, 1) - tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrChan <- e }) - tracer.EXPECT().Close() + expectedErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true} + tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any(), gomock.Any()) + tracerErrChan := make(chan error, 1) + tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrChan <- e }) + tracer.EXPECT().Close() - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() - p := getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 1, []byte("encrypted")) - tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: monotime.Now()}) + p := getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 1, []byte("encrypted")) + tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: monotime.Now()}) - select { - case err := <-errChan: - require.ErrorIs(t, err, expectedErr) - case <-time.After(time.Second): - t.Fatal("timeout") - } - select { - case err := <-tracerErrChan: - require.ErrorIs(t, err, expectedErr) - case <-time.After(time.Second): - t.Fatal("timeout") - } + synctest.Wait() + + select { + case err := <-errChan: + require.ErrorIs(t, err, expectedErr) + default: + t.Fatal("connection was not closed") + } + select { + case err := <-tracerErrChan: + require.ErrorIs(t, err, expectedErr) + default: + t.Fatal("tracer didn't receive event") + } + }) } func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) { @@ -1445,127 +1441,125 @@ func testConnectionReceivePrioritization(t *testing.T, handshakeComplete bool, n } func TestConnectionPacketBuffering(t *testing.T) { - mockCtrl := gomock.NewController(t) - unpacker := NewMockUnpacker(mockCtrl) - cs := mocks.NewMockCryptoSetup(mockCtrl) - tracer, tr := mocklogging.NewMockConnectionTracer(mockCtrl) - tc := newServerTestConnection(t, - mockCtrl, - nil, - false, - connectionOptUnpacker(unpacker), - connectionOptCryptoSetup(cs), - connectionOptTracer(tracer), - ) + synctest.Test(t, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + unpacker := NewMockUnpacker(mockCtrl) + cs := mocks.NewMockCryptoSetup(mockCtrl) + tracer, tr := mocklogging.NewMockConnectionTracer(mockCtrl) + tc := newServerTestConnection(t, + mockCtrl, + nil, + false, + connectionOptUnpacker(unpacker), + connectionOptCryptoSetup(cs), + connectionOptTracer(tracer), + ) - tr.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()) - tr.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) - tr.EXPECT().DroppedEncryptionLevel(gomock.Any()) - cs.EXPECT().DiscardInitialKeys() + tr.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any()) + tr.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + tr.EXPECT().DroppedEncryptionLevel(gomock.Any()) + cs.EXPECT().DiscardInitialKeys() - hdr1 := wire.ExtendedHeader{ - Header: wire.Header{ - Type: protocol.PacketTypeHandshake, - DestConnectionID: tc.srcConnID, - SrcConnectionID: tc.destConnID, - Length: 8, - Version: protocol.Version1, - }, - PacketNumberLen: protocol.PacketNumberLen1, - PacketNumber: 1, - } - hdr2 := hdr1 - hdr2.PacketNumber = 2 - cs.EXPECT().StartHandshake(gomock.Any()) - buffered := make(chan struct{}) - gomock.InOrder( - cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), - tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()), - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), - tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()).Do( - func(logging.PacketType, logging.ByteCount) { close(buffered) }, - ), - ) - - tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr1, []byte("packet1"))) - tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr2, []byte("packet2"))) - - errChan := make(chan error, 1) - go func() { errChan <- tc.conn.run() }() - - select { - case <-buffered: - case <-time.After(time.Second): - t.Fatal("timeout") - } - - // Now send another packet. - // In reality, this packet would contain a CRYPTO frame that advances the TLS handshake - // such that new keys become available. - var packets []string - hdr3 := hdr1 - hdr3.PacketNumber = 3 - tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - unpacked := make(chan struct{}) - cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedReadKeys}) - cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) - - gomock.InOrder( - // packet 3 contains a CRYPTO frame and triggers the keys to become available - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( - func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { - packets = append(packets, string(data[len(data)-7:])) - cf := &wire.CryptoFrame{Data: []byte("foobar")} - b, _ := cf.Append(nil, protocol.Version1) - return &unpackedPacket{hdr: &hdr3, encryptionLevel: protocol.EncryptionHandshake, data: b}, nil + hdr1 := wire.ExtendedHeader{ + Header: wire.Header{ + Type: protocol.PacketTypeHandshake, + DestConnectionID: tc.srcConnID, + SrcConnectionID: tc.destConnID, + Length: 8, + Version: protocol.Version1, }, - ), - cs.EXPECT().HandleMessage(gomock.Any(), gomock.Any()), - tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - // packet 1 dequeued from the buffer - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( - func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { - packets = append(packets, string(data[len(data)-7:])) - return &unpackedPacket{hdr: &hdr1, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil - }, - ), - tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - // packet 2 dequeued from the buffer - unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( - func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { - packets = append(packets, string(data[len(data)-7:])) - close(unpacked) - return &unpackedPacket{hdr: &hdr2, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil - }, - ), - tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), - ) + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, + } + hdr2 := hdr1 + hdr2.PacketNumber = 2 + cs.EXPECT().StartHandshake(gomock.Any()) + buffered := make(chan struct{}) + gomock.InOrder( + cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}), + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), + tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()), + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable), + tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()).Do( + func(logging.PacketType, logging.ByteCount) { close(buffered) }, + ), + ) - tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr3, []byte("packet3"))) + tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr1, []byte("packet1"))) + tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr2, []byte("packet2"))) - select { - case <-unpacked: - case <-time.After(time.Second): - t.Fatal("timeout") - } + errChan := make(chan error, 1) + go func() { errChan <- tc.conn.run() }() - // packet3 triggered the keys to become available - // packet1 and packet2 are processed from the buffer in order - require.Equal(t, []string{"packet3", "packet1", "packet2"}, packets) + select { + case <-buffered: + case <-time.After(time.Second): + t.Fatal("timeout") + } - // test teardown - tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() - cs.EXPECT().Close() - tr.EXPECT().ClosedConnection(gomock.Any()) - tr.EXPECT().Close() - tc.conn.destroy(nil) - select { - case err := <-errChan: - require.NoError(t, err) - case <-time.After(time.Second): - t.Fatal("timeout") - } + // Now send another packet. + // In reality, this packet would contain a CRYPTO frame that advances the TLS handshake + // such that new keys become available. + var packets []string + hdr3 := hdr1 + hdr3.PacketNumber = 3 + tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + unpacked := make(chan struct{}) + cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedReadKeys}) + cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}) + + gomock.InOrder( + // packet 3 contains a CRYPTO frame and triggers the keys to become available + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( + func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { + packets = append(packets, string(data[len(data)-7:])) + cf := &wire.CryptoFrame{Data: []byte("foobar")} + b, _ := cf.Append(nil, protocol.Version1) + return &unpackedPacket{hdr: &hdr3, encryptionLevel: protocol.EncryptionHandshake, data: b}, nil + }, + ), + cs.EXPECT().HandleMessage(gomock.Any(), gomock.Any()), + tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + // packet 1 dequeued from the buffer + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( + func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { + packets = append(packets, string(data[len(data)-7:])) + return &unpackedPacket{hdr: &hdr1, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil + }, + ), + tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + // packet 2 dequeued from the buffer + unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn( + func(hdr *wire.Header, data []byte) (*unpackedPacket, error) { + packets = append(packets, string(data[len(data)-7:])) + close(unpacked) + return &unpackedPacket{hdr: &hdr2, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil + }, + ), + tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()), + ) + + tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr3, []byte("packet3"))) + + synctest.Wait() + + // packet3 triggered the keys to become available + // packet1 and packet2 are processed from the buffer in order + require.Equal(t, []string{"packet3", "packet1", "packet2"}, packets) + + // test teardown + tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes() + cs.EXPECT().Close() + tr.EXPECT().ClosedConnection(gomock.Any()) + tr.EXPECT().Close() + tc.conn.destroy(nil) + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(time.Second): + t.Fatal("timeout") + } + }) } func TestConnectionPacketPacing(t *testing.T) {