From a528c4c4daf1402ca4d299f57657dbbccbc1bfd2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 7 Sep 2020 19:30:04 +0700 Subject: [PATCH] close session on errors unpacking errors other than decryption errors --- packet_unpacker.go | 26 ++++++++++++++++++++++++-- packet_unpacker_test.go | 41 ++++++++++++++++++++++++++++++++++------- session.go | 18 +++++++++++++++--- session_test.go | 34 ++++++++++++++++++++++++++++++---- 4 files changed, 103 insertions(+), 16 deletions(-) diff --git a/packet_unpacker.go b/packet_unpacker.go index ddf35408d..d4195fd62 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -15,6 +15,24 @@ type headerDecryptor interface { DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) } +type headerParseError struct { + err error +} + +func (e *headerParseError) Is(err error) bool { + _, ok := err.(*headerParseError) + return ok +} + +func (e *headerParseError) Unwrap() error { + fmt.Println("unwrap") + return e.err +} + +func (e *headerParseError) Error() string { + return e.err.Error() +} + type unpackedPacket struct { packetNumber protocol.PacketNumber // the decoded packet number hdr *wire.ExtendedHeader @@ -40,6 +58,9 @@ func newPacketUnpacker(cs handshake.CryptoSetup, version protocol.VersionNumber) } } +// If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. +// If any other error occurred when parsing the header, the error is of type headerParseError. +// If decrypting the payload fails for any reason, the error is the error returned by the AEAD. func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error) { var encLevel protocol.EncryptionLevel var extHdr *wire.ExtendedHeader @@ -107,7 +128,7 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene // This avoids a timing side-channel, which otherwise might allow an attacker // to gain information about the header encryption. if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { - return nil, nil, fmt.Errorf("error parsing extended header: %s", parseErr) + return nil, nil, parseErr } extHdrLen := extHdr.ParsedLen() decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) @@ -144,10 +165,11 @@ func (u *packetUnpacker) unpackShortHeaderPacket( return extHdr, decrypted, nil } +// The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) { extHdr, err := unpackHeader(hd, hdr, data, u.version) if err != nil && err != wire.ErrInvalidReservedBits { - return nil, err + return nil, &headerParseError{err: err} } extHdr.PacketNumber = protocol.DecodePacketNumber( extHdr.PacketNumberLen, diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index 1f7fe2416..d68b0da2c 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -5,6 +5,8 @@ import ( "errors" "time" + "github.com/lucas-clemente/quic-go/internal/qerr" + "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/mocks" @@ -17,6 +19,7 @@ import ( var _ = Describe("Packet Unpacker", func() { const version = protocol.VersionTLS + var ( unpacker *packetUnpacker cs *mocks.MockCryptoSetup @@ -26,7 +29,7 @@ var _ = Describe("Packet Unpacker", func() { getHeader := func(extHdr *wire.ExtendedHeader) (*wire.Header, []byte) { buf := &bytes.Buffer{} - ExpectWithOffset(1, extHdr.Write(buf, protocol.VersionWhatever)).To(Succeed()) + ExpectWithOffset(1, extHdr.Write(buf, version)).To(Succeed()) hdrLen := buf.Len() if extHdr.Length > protocol.ByteCount(extHdr.PacketNumberLen) { buf.Write(make([]byte, int(extHdr.Length)-int(extHdr.PacketNumberLen))) @@ -41,7 +44,29 @@ var _ = Describe("Packet Unpacker", func() { unpacker = newPacketUnpacker(cs, version).(*packetUnpacker) }) - It("errors when the packet is too small to obtain the header decryption sample", func() { + It("errors when the packet is too small to obtain the header decryption sample, for long headers", func() { + extHdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + Version: version, + }, + PacketNumber: 1337, + PacketNumberLen: protocol.PacketNumberLen2, + } + hdr, hdrRaw := getHeader(extHdr) + data := append(hdrRaw, make([]byte, 2 /* fill up packet number */ +15 /* need 16 bytes */)...) + opener := mocks.NewMockLongHeaderOpener(mockCtrl) + cs.EXPECT().GetHandshakeOpener().Return(opener, nil) + _, err := unpacker.Unpack(hdr, time.Now(), data) + Expect(errors.Is(err, &headerParseError{})).To(BeTrue()) + var headerErr *headerParseError + Expect(errors.As(err, &headerErr)).To(BeTrue()) + Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) + }) + + It("errors when the packet is too small to obtain the header decryption sample, for short headers", func() { extHdr := &wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: connID}, PacketNumber: 1337, @@ -52,6 +77,9 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockShortHeaderOpener(mockCtrl) cs.EXPECT().Get1RTTOpener().Return(opener, nil) _, err := unpacker.Unpack(hdr, time.Now(), data) + Expect(errors.Is(err, &headerParseError{})).To(BeTrue()) + var headerErr *headerParseError + Expect(errors.As(err, &headerErr)).To(BeTrue()) Expect(err).To(MatchError("Packet too small. Expected at least 20 bytes after the header, got 19")) }) @@ -129,9 +157,9 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockLongHeaderOpener(mockCtrl) cs.EXPECT().GetHandshakeOpener().Return(opener, nil) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) - opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("test err")) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.CryptoBufferExceeded) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError("test err")) + Expect(err).To(MatchError(qerr.CryptoBufferExceeded)) }) It("defends against the timing side-channel when the reserved bits are wrong, for long header packets", func() { @@ -182,10 +210,9 @@ var _ = Describe("Packet Unpacker", func() { opener := mocks.NewMockShortHeaderOpener(mockCtrl) opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()) cs.EXPECT().Get1RTTOpener().Return(opener, nil) - testErr := errors.New("decryption error") - opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) + opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed) _, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...)) - Expect(err).To(MatchError(testErr)) + Expect(err).To(MatchError(handshake.ErrDecryptionFailed)) }) It("decrypts the header", func() { diff --git a/session.go b/session.go index b1a489014..e813510e0 100644 --- a/session.go +++ b/session.go @@ -816,13 +816,25 @@ func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool / s.tryQueueingUndecryptablePacket(p, hdr) case wire.ErrInvalidReservedBits: s.closeLocal(qerr.NewError(qerr.ProtocolViolation, err.Error())) - default: - // This might be a packet injected by an attacker. - // Drop it. + case handshake.ErrDecryptionFailed: + // This might be a packet injected by an attacker. Drop it. if s.tracer != nil { s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError) } s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err) + default: + var headerErr *headerParseError + if errors.As(err, &headerErr) { + // This might be a packet injected by an attacker. Drop it. + if s.tracer != nil { + s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err) + } else { + // This is an error returned by the AEAD (other than ErrDecryptionFailed). + // For example, a PROTOCOL_VIOLATION due to key updates. + s.closeLocal(err) + } } return false } diff --git a/session_test.go b/session_test.go index 07cb3e91b..68c271eb2 100644 --- a/session_test.go +++ b/session_test.go @@ -882,12 +882,11 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) - It("ignores packets when unpacking fails for any other reason", func() { - testErr := errors.New("test err") + It("ignores packets when unpacking the header fails", func() { + testErr := &headerParseError{errors.New("test error")} unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, testErr) streamManager.EXPECT().CloseWithError(gomock.Any()) cryptoSetup.EXPECT().Close() - packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) runErr := make(chan error) go func() { defer GinkgoRecover() @@ -895,13 +894,14 @@ var _ = Describe("Session", func() { runErr <- sess.run() }() expectReplaceWithClosed() - tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropPayloadDecryptError) + tracer.EXPECT().DroppedPacket(logging.PacketType1RTT, gomock.Any(), logging.PacketDropHeaderParseError) sess.handlePacket(getPacket(&wire.ExtendedHeader{ Header: wire.Header{DestConnectionID: srcConnID}, PacketNumberLen: protocol.PacketNumberLen1, }, nil)) Consistently(runErr).ShouldNot(Receive()) // make the go routine return + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() mconn.EXPECT().Write(gomock.Any()) @@ -909,6 +909,32 @@ var _ = Describe("Session", func() { Eventually(sess.Context().Done()).Should(BeClosed()) }) + It("closes the session when unpacking fails because of an error other than a decryption error", func() { + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.ConnectionIDLimitError) + streamManager.EXPECT().CloseWithError(gomock.Any()) + cryptoSetup.EXPECT().Close() + packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + err := sess.run() + Expect(err).To(HaveOccurred()) + Expect(err.(qerr.ErrorCode)).To(Equal(qerr.ConnectionIDLimitError)) + close(done) + }() + expectReplaceWithClosed() + mconn.EXPECT().Write(gomock.Any()) + packet := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil) + tracer.EXPECT().ClosedConnection(gomock.Any()) + tracer.EXPECT().Close() + sess.handlePacket(packet) + Eventually(sess.Context().Done()).Should(BeClosed()) + }) + It("rejects packets with empty payload", func() { unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any(), gomock.Any()).Return(&unpackedPacket{ hdr: &wire.ExtendedHeader{},