From 46546ce49738a54641dd701bbe14e41f8c51c025 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 29 Dec 2024 16:42:24 +0800 Subject: [PATCH] qerr: fix errors.Is for TransportError and ApplicationError (#4824) --- internal/qerr/errorcodes_test.go | 2 +- internal/qerr/errors.go | 45 +++++++++++----------- internal/qerr/errors_test.go | 65 +++++++++++++++++++++++++------- 3 files changed, 73 insertions(+), 39 deletions(-) diff --git a/internal/qerr/errorcodes_test.go b/internal/qerr/errorcodes_test.go index d3dcd53bf..78b9f436c 100644 --- a/internal/qerr/errorcodes_test.go +++ b/internal/qerr/errorcodes_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestTransportErrorStringer(t *testing.T) { +func TestTransportErrorCodeStringer(t *testing.T) { _, thisfile, _, ok := runtime.Caller(0) require.True(t, ok, "Failed to get current frame") diff --git a/internal/qerr/errors.go b/internal/qerr/errors.go index 0879bee4b..7fe1c293d 100644 --- a/internal/qerr/errors.go +++ b/internal/qerr/errors.go @@ -48,21 +48,16 @@ func (e *TransportError) Error() string { return str + ": " + msg } -func (e *TransportError) Is(target error) bool { - return target == net.ErrClosed -} +func (e *TransportError) Unwrap() []error { return []error{net.ErrClosed, e.error} } -func (e *TransportError) Unwrap() error { - return e.error +func (e *TransportError) Is(target error) bool { + t, ok := target.(*TransportError) + return ok && e.ErrorCode == t.ErrorCode && e.FrameType == t.FrameType && e.Remote == t.Remote } // An ApplicationErrorCode is an application-defined error code. type ApplicationErrorCode uint64 -func (e *ApplicationError) Is(target error) bool { - return target == net.ErrClosed -} - // A StreamErrorCode is an error code used to cancel streams. type StreamErrorCode uint64 @@ -81,23 +76,30 @@ func (e *ApplicationError) Error() string { return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage) } +func (e *ApplicationError) Unwrap() error { return net.ErrClosed } + +func (e *ApplicationError) Is(target error) bool { + t, ok := target.(*ApplicationError) + return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote +} + type IdleTimeoutError struct{} var _ error = &IdleTimeoutError{} -func (e *IdleTimeoutError) Timeout() bool { return true } -func (e *IdleTimeoutError) Temporary() bool { return false } -func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } -func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed } +func (e *IdleTimeoutError) Timeout() bool { return true } +func (e *IdleTimeoutError) Temporary() bool { return false } +func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" } +func (e *IdleTimeoutError) Unwrap() error { return net.ErrClosed } type HandshakeTimeoutError struct{} var _ error = &HandshakeTimeoutError{} -func (e *HandshakeTimeoutError) Timeout() bool { return true } -func (e *HandshakeTimeoutError) Temporary() bool { return false } -func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } -func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed } +func (e *HandshakeTimeoutError) Timeout() bool { return true } +func (e *HandshakeTimeoutError) Temporary() bool { return false } +func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" } +func (e *HandshakeTimeoutError) Unwrap() error { return net.ErrClosed } // A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version. type VersionNegotiationError struct { @@ -109,9 +111,7 @@ func (e *VersionNegotiationError) Error() string { return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs) } -func (e *VersionNegotiationError) Is(target error) bool { - return target == net.ErrClosed -} +func (e *VersionNegotiationError) Unwrap() error { return net.ErrClosed } // A StatelessResetError occurs when we receive a stateless reset. type StatelessResetError struct{} @@ -122,10 +122,7 @@ func (e *StatelessResetError) Error() string { return "received a stateless reset" } -func (e *StatelessResetError) Is(target error) bool { - return target == net.ErrClosed -} - +func (e *StatelessResetError) Unwrap() error { return net.ErrClosed } func (e *StatelessResetError) Timeout() bool { return false } func (e *StatelessResetError) Temporary() bool { return true } diff --git a/internal/qerr/errors_test.go b/internal/qerr/errors_test.go index 5e39c1703..44ab88e70 100644 --- a/internal/qerr/errors_test.go +++ b/internal/qerr/errors_test.go @@ -11,7 +11,20 @@ import ( "github.com/stretchr/testify/require" ) -func TestTransportErrorStriner(t *testing.T) { +func TestTransportError(t *testing.T) { + require.True(t, errors.Is(&TransportError{}, net.ErrClosed)) + + require.True(t, errors.Is( + &TransportError{Remote: true, ErrorCode: 1, FrameType: 2}, + &TransportError{Remote: true, ErrorCode: 1, FrameType: 2}, + )) + require.True(t, errors.Is(&TransportError{ErrorCode: 0x101}, &TransportError{ErrorCode: 0x101})) + require.False(t, errors.Is(&TransportError{}, &TransportError{ErrorCode: 0x101})) + require.False(t, errors.Is(&TransportError{}, &TransportError{FrameType: 0x1337})) + require.False(t, errors.Is(&TransportError{Remote: false}, &TransportError{Remote: true})) +} + +func TestTransportErrorStringer(t *testing.T) { t.Run("with error message", func(t *testing.T) { err := &TransportError{ ErrorCode: FlowControlError, @@ -50,14 +63,26 @@ var _ error = myError(0) func (e myError) Error() string { return fmt.Sprintf("my error %d", e) } -func TestCryptoErrorUnwrapsErrors(t *testing.T) { +func TestCryptoError(t *testing.T) { var myErr myError err := NewLocalCryptoError(0x42, myError(1337)) require.True(t, errors.As(err, &myErr)) require.Equal(t, myError(1337), myErr) + + testErr := errors.New("test error") + err = NewLocalCryptoError(0x42, testErr) + require.True(t, errors.Is(err, testErr)) + require.True(t, errors.Is( + NewLocalCryptoError(0x42, testErr), + NewLocalCryptoError(0x42, testErr), + )) + require.False(t, errors.Is( + NewLocalCryptoError(0x42, testErr), + NewLocalCryptoError(0x43, testErr), + )) } -func TestCryptoErrorStringRepresentation(t *testing.T) { +func TestCryptoErrorStringer(t *testing.T) { t.Run("with error message", func(t *testing.T) { myErr := myError(1337) err := NewLocalCryptoError(0x42, myErr) @@ -71,6 +96,18 @@ func TestCryptoErrorStringRepresentation(t *testing.T) { } func TestApplicationError(t *testing.T) { + require.True(t, errors.Is(&ApplicationError{}, net.ErrClosed)) + + require.True(t, errors.Is( + &ApplicationError{ErrorCode: 1, Remote: true}, + &ApplicationError{ErrorCode: 1, Remote: true}, + )) + require.True(t, errors.Is(&ApplicationError{ErrorCode: 0x101}, &ApplicationError{ErrorCode: 0x101})) + require.False(t, errors.Is(&ApplicationError{}, &ApplicationError{ErrorCode: 0x101})) + require.False(t, errors.Is(&ApplicationError{Remote: false}, &ApplicationError{Remote: true})) +} + +func TestApplicationErrorStringer(t *testing.T) { t.Run("with error message", func(t *testing.T) { err := &ApplicationError{ ErrorCode: 0x42, @@ -89,6 +126,9 @@ func TestApplicationError(t *testing.T) { } func TestHandshakeTimeoutError(t *testing.T) { + require.True(t, errors.Is(&HandshakeTimeoutError{}, &HandshakeTimeoutError{})) + require.False(t, errors.Is(&HandshakeTimeoutError{}, &IdleTimeoutError{})) + //nolint:gosimple // we need to assign to an interface here var err error err = &HandshakeTimeoutError{} @@ -96,9 +136,13 @@ func TestHandshakeTimeoutError(t *testing.T) { require.True(t, ok) require.True(t, nerr.Timeout()) require.Equal(t, "timeout: handshake did not complete in time", err.Error()) + require.True(t, errors.Is(&HandshakeTimeoutError{}, net.ErrClosed)) } func TestIdleTimeoutError(t *testing.T) { + require.True(t, errors.Is(&IdleTimeoutError{}, &IdleTimeoutError{})) + require.False(t, errors.Is(&IdleTimeoutError{}, &HandshakeTimeoutError{})) + //nolint:gosimple // we need to assign to an interface here var err error err = &IdleTimeoutError{} @@ -106,6 +150,7 @@ func TestIdleTimeoutError(t *testing.T) { require.True(t, ok) require.True(t, nerr.Timeout()) require.Equal(t, "timeout: no recent network activity", err.Error()) + require.True(t, errors.Is(&IdleTimeoutError{}, net.ErrClosed)) } func TestVersionNegotiationErrorString(t *testing.T) { @@ -114,26 +159,18 @@ func TestVersionNegotiationErrorString(t *testing.T) { Theirs: []protocol.Version{4, 5, 6}, } require.Equal(t, "no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])", err.Error()) + require.True(t, errors.Is(&VersionNegotiationError{}, net.ErrClosed)) } -func TestStatelessResetErrorString(t *testing.T) { +func TestStatelessResetError(t *testing.T) { require.Equal(t, "received a stateless reset", (&StatelessResetError{}).Error()) -} + require.True(t, errors.Is(&StatelessResetError{}, &StatelessResetError{})) -func TestStatelessResetErrorIsNetError(t *testing.T) { //nolint:gosimple // we need to assign to an interface here var err error err = &StatelessResetError{} nerr, ok := err.(net.Error) require.True(t, ok) require.False(t, nerr.Timeout()) -} - -func TestErrorsAreNetErrClosed(t *testing.T) { - require.True(t, errors.Is(&TransportError{}, net.ErrClosed)) - require.True(t, errors.Is(&ApplicationError{}, net.ErrClosed)) - require.True(t, errors.Is(&IdleTimeoutError{}, net.ErrClosed)) - require.True(t, errors.Is(&HandshakeTimeoutError{}, net.ErrClosed)) require.True(t, errors.Is(&StatelessResetError{}, net.ErrClosed)) - require.True(t, errors.Is(&VersionNegotiationError{}, net.ErrClosed)) }