diff --git a/client.go b/client.go index 1938635c..97f03bc6 100644 --- a/client.go +++ b/client.go @@ -233,6 +233,24 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error { c.mutex.Lock() defer c.mutex.Unlock() + if hdr.ResetFlag { + cr := c.conn.RemoteAddr() + // check if the remote address and the connection ID match + // otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection + if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || hdr.ConnectionID != c.connectionID { + utils.Infof("Received a spoofed Public Reset. Ignoring.") + return nil + } + pr, err := parsePublicReset(r) + if err != nil { + utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.") + return nil + } + utils.Infof("Received Public Reset, rejected packet number: %#x.", pr.rejectedPacketNumber) + c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.rejectedPacketNumber))) + return nil + } + // ignore delayed / duplicated version negotiation packets if c.versionNegotiated && hdr.VersionFlag { return nil diff --git a/client_test.go b/client_test.go index ece3a012..ce5d1856 100644 --- a/client_test.go +++ b/client_test.go @@ -30,11 +30,14 @@ var _ = Describe("Client", func() { Eventually(areSessionsRunning).Should(BeFalse()) msess, _, _ := newMockSession(nil, 0, 0, nil, nil, nil) sess = msess.(*mockSession) - packetConn = &mockPacketConn{addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}} + addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + packetConn = &mockPacketConn{ + addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}, + dataReadFrom: addr, + } config = &Config{ Versions: []protocol.VersionNumber{protocol.SupportedVersions[0], 77, 78}, } - addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} cl = &client{ config: config, connectionID: 0x1337, @@ -379,7 +382,7 @@ var _ = Describe("Client", func() { It("closes the session when encountering an error while handling a packet", func() { Expect(sess.closeReason).ToNot(HaveOccurred()) - packetConn.dataToRead = bytes.Repeat([]byte{0xff}, 100) + packetConn.dataToRead = []byte("invalid packet") cl.listen() Expect(sess.closed).To(BeTrue()) Expect(sess.closeReason).To(HaveOccurred()) @@ -393,4 +396,37 @@ var _ = Describe("Client", func() { Expect(sess.closeReason).To(MatchError(testErr)) }) }) + + Context("Public Reset handling", func() { + It("closes the session when receiving a Public Reset", func() { + err := cl.handlePacket(addr, writePublicReset(cl.connectionID, 1, 0)) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.session.(*mockSession).closed).To(BeTrue()) + Expect(cl.session.(*mockSession).closedRemote).To(BeTrue()) + Expect(cl.session.(*mockSession).closeReason.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset)) + }) + + It("ignores Public Resets with the wrong connection ID", func() { + err := cl.handlePacket(addr, writePublicReset(cl.connectionID+1, 1, 0)) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.session.(*mockSession).closed).To(BeFalse()) + Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) + }) + + It("ignores Public Resets from the wrong remote address", func() { + spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678} + err := cl.handlePacket(spoofedAddr, writePublicReset(cl.connectionID, 1, 0)) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.session.(*mockSession).closed).To(BeFalse()) + Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) + }) + + It("ignores unparseable Public Resets", func() { + pr := writePublicReset(cl.connectionID, 1, 0) + err := cl.handlePacket(addr, pr[:len(pr)-5]) + Expect(err).ToNot(HaveOccurred()) + Expect(cl.session.(*mockSession).closed).To(BeFalse()) + Expect(cl.session.(*mockSession).closedRemote).To(BeFalse()) + }) + }) }) diff --git a/server.go b/server.go index d45168b3..76f07bab 100644 --- a/server.go +++ b/server.go @@ -20,6 +20,7 @@ type packetHandler interface { Session handlePacket(*receivedPacket) run() error + closeRemote(error) } // A Listener of QUIC diff --git a/server_test.go b/server_test.go index 893754fe..5c3977d9 100644 --- a/server_test.go +++ b/server_test.go @@ -23,6 +23,7 @@ type mockSession struct { packetCount int closed bool closeReason error + closedRemote bool stopRunLoop chan struct{} // run returns as soon as this channel receives a value handshakeChan chan handshakeEvent handshakeComplete chan error // for WaitUntilHandshakeComplete @@ -51,6 +52,12 @@ func (s *mockSession) Close(e error) error { close(s.stopRunLoop) return nil } +func (s *mockSession) closeRemote(e error) { + s.closeReason = e + s.closed = true + s.closedRemote = true + close(s.stopRunLoop) +} func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") } diff --git a/session.go b/session.go index 0569df0c..2769f074 100644 --- a/session.go +++ b/session.go @@ -441,7 +441,7 @@ func (s *session) handleFrames(fs []frames.Frame) error { case *frames.AckFrame: err = s.handleAckFrame(frame) case *frames.ConnectionCloseFrame: - s.close(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true) + s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase)) case *frames.GoawayFrame: err = errors.New("unimplemented: handling GOAWAY frames") case *frames.StopWaitingFrame: @@ -527,20 +527,22 @@ func (s *session) handleAckFrame(frame *frames.AckFrame) error { return s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime) } -func (s *session) close(e error, remoteClose bool) { +func (s *session) closeLocal(e error) { s.closeOnce.Do(func() { - s.closeChan <- closeError{err: e, remote: remoteClose} + s.closeChan <- closeError{err: e, remote: false} }) } -func (s *session) closeLocal(e error) { - s.close(e, false) +func (s *session) closeRemote(e error) { + s.closeOnce.Do(func() { + s.closeChan <- closeError{err: e, remote: true} + }) } // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // It waits until the run loop has stopped before returning func (s *session) Close(e error) error { - s.close(e, false) + s.closeLocal(e) <-s.runClosed return nil }