qerr: fix errors.Is for TransportError and ApplicationError (#4824)

This commit is contained in:
Marten Seemann
2024-12-29 16:42:24 +08:00
committed by GitHub
parent f45daa0c05
commit 46546ce497
3 changed files with 73 additions and 39 deletions

View File

@@ -12,7 +12,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestTransportErrorStringer(t *testing.T) {
func TestTransportErrorCodeStringer(t *testing.T) {
_, thisfile, _, ok := runtime.Caller(0)
require.True(t, ok, "Failed to get current frame")

View File

@@ -48,21 +48,16 @@ func (e *TransportError) Error() string {
return str + ": " + msg
}
func (e *TransportError) Is(target error) bool {
return target == net.ErrClosed
}
func (e *TransportError) Unwrap() []error { return []error{net.ErrClosed, e.error} }
func (e *TransportError) Unwrap() error {
return e.error
func (e *TransportError) Is(target error) bool {
t, ok := target.(*TransportError)
return ok && e.ErrorCode == t.ErrorCode && e.FrameType == t.FrameType && e.Remote == t.Remote
}
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint64
func (e *ApplicationError) Is(target error) bool {
return target == net.ErrClosed
}
// A StreamErrorCode is an error code used to cancel streams.
type StreamErrorCode uint64
@@ -81,23 +76,30 @@ func (e *ApplicationError) Error() string {
return fmt.Sprintf("Application error %#x (%s): %s", e.ErrorCode, getRole(e.Remote), e.ErrorMessage)
}
func (e *ApplicationError) Unwrap() error { return net.ErrClosed }
func (e *ApplicationError) Is(target error) bool {
t, ok := target.(*ApplicationError)
return ok && e.ErrorCode == t.ErrorCode && e.Remote == t.Remote
}
type IdleTimeoutError struct{}
var _ error = &IdleTimeoutError{}
func (e *IdleTimeoutError) Timeout() bool { return true }
func (e *IdleTimeoutError) Temporary() bool { return false }
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
func (e *IdleTimeoutError) Is(target error) bool { return target == net.ErrClosed }
func (e *IdleTimeoutError) Timeout() bool { return true }
func (e *IdleTimeoutError) Temporary() bool { return false }
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
func (e *IdleTimeoutError) Unwrap() error { return net.ErrClosed }
type HandshakeTimeoutError struct{}
var _ error = &HandshakeTimeoutError{}
func (e *HandshakeTimeoutError) Timeout() bool { return true }
func (e *HandshakeTimeoutError) Temporary() bool { return false }
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
func (e *HandshakeTimeoutError) Is(target error) bool { return target == net.ErrClosed }
func (e *HandshakeTimeoutError) Timeout() bool { return true }
func (e *HandshakeTimeoutError) Temporary() bool { return false }
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
func (e *HandshakeTimeoutError) Unwrap() error { return net.ErrClosed }
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
type VersionNegotiationError struct {
@@ -109,9 +111,7 @@ func (e *VersionNegotiationError) Error() string {
return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs)
}
func (e *VersionNegotiationError) Is(target error) bool {
return target == net.ErrClosed
}
func (e *VersionNegotiationError) Unwrap() error { return net.ErrClosed }
// A StatelessResetError occurs when we receive a stateless reset.
type StatelessResetError struct{}
@@ -122,10 +122,7 @@ func (e *StatelessResetError) Error() string {
return "received a stateless reset"
}
func (e *StatelessResetError) Is(target error) bool {
return target == net.ErrClosed
}
func (e *StatelessResetError) Unwrap() error { return net.ErrClosed }
func (e *StatelessResetError) Timeout() bool { return false }
func (e *StatelessResetError) Temporary() bool { return true }

View File

@@ -11,7 +11,20 @@ import (
"github.com/stretchr/testify/require"
)
func TestTransportErrorStriner(t *testing.T) {
func TestTransportError(t *testing.T) {
require.True(t, errors.Is(&TransportError{}, net.ErrClosed))
require.True(t, errors.Is(
&TransportError{Remote: true, ErrorCode: 1, FrameType: 2},
&TransportError{Remote: true, ErrorCode: 1, FrameType: 2},
))
require.True(t, errors.Is(&TransportError{ErrorCode: 0x101}, &TransportError{ErrorCode: 0x101}))
require.False(t, errors.Is(&TransportError{}, &TransportError{ErrorCode: 0x101}))
require.False(t, errors.Is(&TransportError{}, &TransportError{FrameType: 0x1337}))
require.False(t, errors.Is(&TransportError{Remote: false}, &TransportError{Remote: true}))
}
func TestTransportErrorStringer(t *testing.T) {
t.Run("with error message", func(t *testing.T) {
err := &TransportError{
ErrorCode: FlowControlError,
@@ -50,14 +63,26 @@ var _ error = myError(0)
func (e myError) Error() string { return fmt.Sprintf("my error %d", e) }
func TestCryptoErrorUnwrapsErrors(t *testing.T) {
func TestCryptoError(t *testing.T) {
var myErr myError
err := NewLocalCryptoError(0x42, myError(1337))
require.True(t, errors.As(err, &myErr))
require.Equal(t, myError(1337), myErr)
testErr := errors.New("test error")
err = NewLocalCryptoError(0x42, testErr)
require.True(t, errors.Is(err, testErr))
require.True(t, errors.Is(
NewLocalCryptoError(0x42, testErr),
NewLocalCryptoError(0x42, testErr),
))
require.False(t, errors.Is(
NewLocalCryptoError(0x42, testErr),
NewLocalCryptoError(0x43, testErr),
))
}
func TestCryptoErrorStringRepresentation(t *testing.T) {
func TestCryptoErrorStringer(t *testing.T) {
t.Run("with error message", func(t *testing.T) {
myErr := myError(1337)
err := NewLocalCryptoError(0x42, myErr)
@@ -71,6 +96,18 @@ func TestCryptoErrorStringRepresentation(t *testing.T) {
}
func TestApplicationError(t *testing.T) {
require.True(t, errors.Is(&ApplicationError{}, net.ErrClosed))
require.True(t, errors.Is(
&ApplicationError{ErrorCode: 1, Remote: true},
&ApplicationError{ErrorCode: 1, Remote: true},
))
require.True(t, errors.Is(&ApplicationError{ErrorCode: 0x101}, &ApplicationError{ErrorCode: 0x101}))
require.False(t, errors.Is(&ApplicationError{}, &ApplicationError{ErrorCode: 0x101}))
require.False(t, errors.Is(&ApplicationError{Remote: false}, &ApplicationError{Remote: true}))
}
func TestApplicationErrorStringer(t *testing.T) {
t.Run("with error message", func(t *testing.T) {
err := &ApplicationError{
ErrorCode: 0x42,
@@ -89,6 +126,9 @@ func TestApplicationError(t *testing.T) {
}
func TestHandshakeTimeoutError(t *testing.T) {
require.True(t, errors.Is(&HandshakeTimeoutError{}, &HandshakeTimeoutError{}))
require.False(t, errors.Is(&HandshakeTimeoutError{}, &IdleTimeoutError{}))
//nolint:gosimple // we need to assign to an interface here
var err error
err = &HandshakeTimeoutError{}
@@ -96,9 +136,13 @@ func TestHandshakeTimeoutError(t *testing.T) {
require.True(t, ok)
require.True(t, nerr.Timeout())
require.Equal(t, "timeout: handshake did not complete in time", err.Error())
require.True(t, errors.Is(&HandshakeTimeoutError{}, net.ErrClosed))
}
func TestIdleTimeoutError(t *testing.T) {
require.True(t, errors.Is(&IdleTimeoutError{}, &IdleTimeoutError{}))
require.False(t, errors.Is(&IdleTimeoutError{}, &HandshakeTimeoutError{}))
//nolint:gosimple // we need to assign to an interface here
var err error
err = &IdleTimeoutError{}
@@ -106,6 +150,7 @@ func TestIdleTimeoutError(t *testing.T) {
require.True(t, ok)
require.True(t, nerr.Timeout())
require.Equal(t, "timeout: no recent network activity", err.Error())
require.True(t, errors.Is(&IdleTimeoutError{}, net.ErrClosed))
}
func TestVersionNegotiationErrorString(t *testing.T) {
@@ -114,26 +159,18 @@ func TestVersionNegotiationErrorString(t *testing.T) {
Theirs: []protocol.Version{4, 5, 6},
}
require.Equal(t, "no compatible QUIC version found (we support [0x2 0x3], server offered [0x4 0x5 0x6])", err.Error())
require.True(t, errors.Is(&VersionNegotiationError{}, net.ErrClosed))
}
func TestStatelessResetErrorString(t *testing.T) {
func TestStatelessResetError(t *testing.T) {
require.Equal(t, "received a stateless reset", (&StatelessResetError{}).Error())
}
require.True(t, errors.Is(&StatelessResetError{}, &StatelessResetError{}))
func TestStatelessResetErrorIsNetError(t *testing.T) {
//nolint:gosimple // we need to assign to an interface here
var err error
err = &StatelessResetError{}
nerr, ok := err.(net.Error)
require.True(t, ok)
require.False(t, nerr.Timeout())
}
func TestErrorsAreNetErrClosed(t *testing.T) {
require.True(t, errors.Is(&TransportError{}, net.ErrClosed))
require.True(t, errors.Is(&ApplicationError{}, net.ErrClosed))
require.True(t, errors.Is(&IdleTimeoutError{}, net.ErrClosed))
require.True(t, errors.Is(&HandshakeTimeoutError{}, net.ErrClosed))
require.True(t, errors.Is(&StatelessResetError{}, net.ErrClosed))
require.True(t, errors.Is(&VersionNegotiationError{}, net.ErrClosed))
}