From 466d50345a8226571ecec8588fa4466c8c19cc7d Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Tue, 17 May 2016 10:37:43 +0200 Subject: [PATCH] move error conversion from session to qerr.ToQuicError --- qerr/quic_error.go | 15 ++++++++++++++ qerr/quic_error_test.go | 33 ++++++++++++++++++++++++++++--- session.go | 43 +++++++++++++++++------------------------ 3 files changed, 63 insertions(+), 28 deletions(-) diff --git a/qerr/quic_error.go b/qerr/quic_error.go index 72e977241..d0a313827 100644 --- a/qerr/quic_error.go +++ b/qerr/quic_error.go @@ -2,6 +2,8 @@ package qerr import ( "fmt" + + "github.com/lucas-clemente/quic-go/utils" ) // ErrorCode can be used as a normal error without reason. @@ -28,3 +30,16 @@ func Error(errorCode ErrorCode, errorMessage string) *QuicError { func (e *QuicError) Error() string { return fmt.Sprintf("%s: %s", e.ErrorCode.String(), e.ErrorMessage) } + +// ToQuicError converts an arbitrary error to a QuicError. It leaves QuicErrors +// unchanged, and properly handles `ErrorCode`s. +func ToQuicError(err error) *QuicError { + switch e := err.(type) { + case *QuicError: + return e + case ErrorCode: + return Error(e, "") + } + utils.Errorf("BUG: Unknown error encountered: %#v", err) + return Error(InternalError, "") +} diff --git a/qerr/quic_error_test.go b/qerr/quic_error_test.go index bc255e3ae..f576d8367 100644 --- a/qerr/quic_error_test.go +++ b/qerr/quic_error_test.go @@ -1,6 +1,8 @@ package qerr_test import ( + "io" + "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -8,8 +10,33 @@ import ( ) var _ = Describe("Quic error", func() { - It("has a string representation", func() { - err := qerr.Error(qerr.InternalError, "foobar") - Expect(err.Error()).To(Equal("InternalError: foobar")) + Context("QuicError", func() { + It("has a string representation", func() { + err := qerr.Error(qerr.DecryptionFailure, "foobar") + Expect(err.Error()).To(Equal("DecryptionFailure: foobar")) + }) + }) + + Context("ErrorCode", func() { + It("works as error", func() { + var err error = qerr.DecryptionFailure + Expect(err).To(MatchError("DecryptionFailure")) + }) + }) + + Context("ToQuicError", func() { + It("leaves QuicError unchanged", func() { + err := qerr.Error(qerr.DecryptionFailure, "foo") + Expect(qerr.ToQuicError(err)).To(Equal(err)) + }) + + It("wraps ErrorCode properly", func() { + var err error = qerr.DecryptionFailure + Expect(qerr.ToQuicError(err)).To(Equal(qerr.Error(qerr.DecryptionFailure, ""))) + }) + + It("changes default errors to InternalError", func() { + Expect(qerr.ToQuicError(io.EOF)).To(Equal(qerr.Error(qerr.InternalError, ""))) + }) }) }) diff --git a/session.go b/session.go index dd709297b..7a4d1b000 100644 --- a/session.go +++ b/session.go @@ -363,37 +363,17 @@ func (s *Session) Close(e error, sendConnectionClose bool) error { } if e == nil { - e = qerr.Error(qerr.PeerGoingAway, "peer going away") + e = qerr.PeerGoingAway } - utils.Errorf("Closing session with error: %s", e.Error()) - // if e is a QUIC error, send it to the client - // else, send the generic QUIC internal error - var errorCode qerr.ErrorCode - var reasonPhrase string - quicError, ok := e.(*qerr.QuicError) - if ok { - errorCode = quicError.ErrorCode - reasonPhrase = e.Error() - } else { - errorCode = qerr.InternalError - } + utils.Errorf("Closing session with error: %s", e.Error()) s.closeStreamsWithError(e) - if errorCode == qerr.DecryptionFailure { + quicErr := qerr.ToQuicError(e) + if quicErr.ErrorCode == qerr.DecryptionFailure { return s.sendPublicReset(s.lastRcvdPacketNumber) } - - packet, err := s.packer.PackPacket(nil, []frames.Frame{ - &frames.ConnectionCloseFrame{ErrorCode: errorCode, ReasonPhrase: reasonPhrase}, - }, false) - if err != nil { - return err - } - if packet == nil { - panic("Session: internal inconsistency: expected packet not to be nil") - } - return s.conn.write(packet.raw) + return s.sendConnectionClose(quicErr) } func (s *Session) closeStreamsWithError(err error) { @@ -537,6 +517,19 @@ func (s *Session) sendPacket() error { return nil } +func (s *Session) sendConnectionClose(quicErr *qerr.QuicError) error { + packet, err := s.packer.PackPacket(nil, []frames.Frame{ + &frames.ConnectionCloseFrame{ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage}, + }, false) + if err != nil { + return err + } + if packet == nil { + panic("Session: internal inconsistency: expected packet not to be nil") + } + return s.conn.write(packet.raw) +} + // queueStreamFrame queues a frame for sending to the client func (s *Session) queueStreamFrame(frame *frames.StreamFrame) error { s.packer.AddStreamFrame(*frame)