diff --git a/mint_utils.go b/mint_utils.go index 02bb3aef..9764a70c 100644 --- a/mint_utils.go +++ b/mint_utils.go @@ -128,14 +128,14 @@ func unpackInitialPacket(aead crypto.AEAD, hdr *wire.Header, data []byte, versio // packUnencryptedPacket provides a low-overhead way to pack a packet. // It is supposed to be used in the early stages of the handshake, before a session (which owns a packetPacker) is available. -func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFrame, pers protocol.Perspective) ([]byte, error) { +func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, f wire.Frame, pers protocol.Perspective) ([]byte, error) { raw := getPacketBuffer() buffer := bytes.NewBuffer(raw) if err := hdr.Write(buffer, pers, hdr.Version); err != nil { return nil, err } payloadStartIndex := buffer.Len() - if err := sf.Write(buffer, hdr.Version); err != nil { + if err := f.Write(buffer, hdr.Version); err != nil { return nil, err } raw = raw[0:buffer.Len()] @@ -144,7 +144,7 @@ func packUnencryptedPacket(aead crypto.AEAD, hdr *wire.Header, sf *wire.StreamFr if utils.Debug() { utils.Debugf("-> Sending packet 0x%x (%d bytes) for connection %x, %s", hdr.PacketNumber, len(raw), hdr.ConnectionID, protocol.EncryptionUnencrypted) hdr.Log() - wire.LogFrame(sf, true) + wire.LogFrame(f, true) } return raw, nil } diff --git a/server_tls.go b/server_tls.go index 881f4a13..c4a5fb1a 100644 --- a/server_tls.go +++ b/server_tls.go @@ -12,6 +12,7 @@ import ( "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" ) type nullAEAD struct { @@ -98,6 +99,26 @@ func (s *serverTLS) newMintConnImpl(bc *handshake.CryptoStreamConn, v protocol.V return tls, extHandler.GetPeerParams(), nil } +func (s *serverTLS) sendConnectionClose(remoteAddr net.Addr, clientHdr *wire.Header, aead crypto.AEAD, closeErr error) error { + ccf := &wire.ConnectionCloseFrame{ + ErrorCode: qerr.HandshakeFailed, + ReasonPhrase: closeErr.Error(), + } + replyHdr := &wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + ConnectionID: clientHdr.ConnectionID, // echo the client's connection ID + PacketNumber: 1, // random packet number + Version: clientHdr.Version, + } + data, err := packUnencryptedPacket(aead, replyHdr, ccf, protocol.PerspectiveServer) + if err != nil { + return err + } + _, err = s.conn.WriteTo(data, remoteAddr) + return err +} + func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, data []byte) (packetHandler, error) { if len(hdr.Raw)+len(data) < protocol.MinInitialPacketSize { return nil, errors.New("dropping too small Initial packet") @@ -110,19 +131,30 @@ func (s *serverTLS) handleInitialImpl(remoteAddr net.Addr, hdr *wire.Header, dat } // unpack packet and check stream frame contents - version := hdr.Version - aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, version) + aead, err := crypto.NewNullAEAD(protocol.PerspectiveServer, hdr.ConnectionID, hdr.Version) if err != nil { return nil, err } - frame, err := unpackInitialPacket(aead, hdr, data, version) + frame, err := unpackInitialPacket(aead, hdr, data, hdr.Version) if err != nil { utils.Debugf("Error unpacking initial packet: %s", err) return nil, nil } + sess, err := s.handleUnpackedInitial(remoteAddr, hdr, frame, aead) + if err != nil { + if ccerr := s.sendConnectionClose(remoteAddr, hdr, aead, err); ccerr != nil { + utils.Debugf("Error sending CONNECTION_CLOSE: ", ccerr) + } + return nil, err + } + return sess, nil +} + +func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header, frame *wire.StreamFrame, aead crypto.AEAD) (packetHandler, error) { + version := hdr.Version bc := handshake.NewCryptoStreamConn(remoteAddr) bc.AddDataForReading(frame.Data) - tls, paramsChan, err := s.newMintConn(bc, hdr.Version) + tls, paramsChan, err := s.newMintConn(bc, version) if err != nil { return nil, err } diff --git a/server_tls_test.go b/server_tls_test.go index 9c8d00ed..7cead5cd 100644 --- a/server_tls_test.go +++ b/server_tls_test.go @@ -4,15 +4,16 @@ import ( "bytes" "io" - "github.com/lucas-clemente/quic-go/internal/mocks" - "github.com/lucas-clemente/quic-go/internal/mocks/handshake" - "github.com/bifurcation/mint" "github.com/lucas-clemente/quic-go/internal/crypto" "github.com/lucas-clemente/quic-go/internal/handshake" + "github.com/lucas-clemente/quic-go/internal/mocks" + "github.com/lucas-clemente/quic-go/internal/mocks/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qerr" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -65,6 +66,18 @@ var _ = Describe("Stateless TLS handling", func() { return hdr, data } + unpackPacket := func(data []byte) (*wire.Header, []byte) { + r := bytes.NewReader(conn.dataWritten.Bytes()) + hdr, err := wire.ParseHeaderSentByServer(r, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + hdr.Raw = data[:len(data)-r.Len()] + aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.ConnectionID, protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + payload, err := aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw) + Expect(err).ToNot(HaveOccurred()) + return hdr, payload + } + It("sends a version negotiation packet if it doesn't support the version", func() { server.HandleInitial(nil, &wire.Header{Version: 0x1337}, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize)) Expect(conn.dataWritten.Len()).ToNot(BeZero()) @@ -124,4 +137,20 @@ var _ = Describe("Stateless TLS handling", func() { Eventually(sessionChan).Should(Receive()) Eventually(done).Should(BeClosed()) }) + + It("sends a CONNECTION_CLOSE, if mint returns an error", func() { + mintTLS.EXPECT().Handshake().Return(mint.AlertAccessDenied) + extHandler.EXPECT().GetPeerParams() + hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")}) + server.HandleInitial(nil, hdr, data) + // the Handshake packet is written by the session + Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty()) + // unpack the packet to check that it actually contains a CONNECTION_CLOSE + hdr, data = unpackPacket(conn.dataWritten.Bytes()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeHandshake)) + ccf, err := wire.ParseConnectionCloseFrame(bytes.NewReader(data), protocol.VersionTLS) + Expect(err).ToNot(HaveOccurred()) + Expect(ccf.ErrorCode).To(Equal(qerr.HandshakeFailed)) + Expect(ccf.ReasonPhrase).To(Equal(mint.AlertAccessDenied.String())) + }) })