diff --git a/server.go b/server.go index 1ddbba654..46b43c7cc 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "crypto/tls" + "errors" "net" "strings" "sync" @@ -153,10 +154,19 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet s.sessionsMutex.RUnlock() if !ok { - utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, hdr.VersionNumber, remoteAddr) + if !hdr.VersionFlag { + _, err = conn.WriteToUDP(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr) + return err + } + version := hdr.VersionNumber + if !protocol.IsSupportedVersion(version) { + return errors.New("Server BUG: negotiated version not supported") + } + + utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr) session, err = s.newSession( &udpConn{conn: conn, currentAddr: remoteAddr}, - hdr.VersionNumber, + version, hdr.ConnectionID, s.scfg, s.streamCallback, diff --git a/server_test.go b/server_test.go index 88d0c85a5..19e5356af 100644 --- a/server_test.go +++ b/server_test.go @@ -9,6 +9,7 @@ import ( "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/testdata" + "github.com/lucas-clemente/quic-go/utils" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -36,7 +37,8 @@ func newMockSession(conn connection, v protocol.VersionNumber, connectionID prot var _ = Describe("Server", func() { Describe("with mock session", func() { var ( - server *Server + server *Server + firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 ) BeforeEach(func() { @@ -44,6 +46,10 @@ var _ = Describe("Server", func() { sessions: map[protocol.ConnectionID]packetHandler{}, newSession: newMockSession, } + b := &bytes.Buffer{} + utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.Version36)) + firstPacket = []byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c} + firstPacket = append(append(firstPacket, b.Bytes()...), 0x01) }) It("composes version negotiation packets", func() { @@ -55,7 +61,7 @@ var _ = Describe("Server", func() { }) It("creates new sessions", func() { - err := server.handlePacket(nil, nil, []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}) + err := server.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(server.sessions).To(HaveLen(1)) Expect(server.sessions[0x4cfa9f9b668619f6].(*mockSession).connectionID).To(Equal(protocol.ConnectionID(0x4cfa9f9b668619f6))) @@ -63,7 +69,7 @@ var _ = Describe("Server", func() { }) It("assigns packets to existing sessions", func() { - err := server.handlePacket(nil, nil, []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}) + err := server.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) err = server.handlePacket(nil, nil, []byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}) Expect(err).ToNot(HaveOccurred()) @@ -73,9 +79,7 @@ var _ = Describe("Server", func() { }) It("closes and deletes sessions", func() { - version := 0x34 - pheader := []byte{0x09, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x51, 0x30, 0x33, byte(version), 0x01} - err := server.handlePacket(nil, nil, append(pheader, (&crypto.NullAEAD{}).Seal(nil, nil, 0, pheader)...)) + err := server.handlePacket(nil, nil, append(firstPacket, (&crypto.NullAEAD{}).Seal(nil, nil, 0, firstPacket)...)) Expect(err).ToNot(HaveOccurred()) Expect(server.sessions).To(HaveLen(1)) server.closeCallback(0x4cfa9f9b668619f6) @@ -150,6 +154,42 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) }) + It("sends a public reset for new connections that don't have the VersionFlag set", func(done Done) { + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + Expect(err).ToNot(HaveOccurred()) + + server, err := NewServer("", testdata.GetTLSConfig(), nil) + Expect(err).ToNot(HaveOccurred()) + + serverConn, err := net.ListenUDP("udp", addr) + Expect(err).NotTo(HaveOccurred()) + + addr = serverConn.LocalAddr().(*net.UDPAddr) + + go func() { + defer GinkgoRecover() + err2 := server.Serve(serverConn) + Expect(err2).ToNot(HaveOccurred()) + close(done) + }() + + clientConn, err := net.DialUDP("udp", nil, addr) + Expect(err).ToNot(HaveOccurred()) + + _, err = clientConn.Write([]byte{0x08, 0xf6, 0x19, 0x86, 0x66, 0x9b, 0x9f, 0xfa, 0x4c, 0x01}) + Expect(err).ToNot(HaveOccurred()) + data := make([]byte, 1000) + var n int + n, _, err = clientConn.ReadFromUDP(data) + Expect(err).NotTo(HaveOccurred()) + Expect(n).ToNot(BeZero()) + Expect(data[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set + Expect(server.sessions).To(BeEmpty()) + + err = server.Close() + Expect(err).ToNot(HaveOccurred()) + }) + It("setups and responds with error on invalid frame", func(done Done) { addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") Expect(err).ToNot(HaveOccurred())