diff --git a/client_test.go b/client_test.go index 1d8e1aa5..c4be1346 100644 --- a/client_test.go +++ b/client_test.go @@ -321,7 +321,7 @@ var _ = Describe("Client", func() { newVersion := protocol.VersionNumber(77) Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) - packetConn.dataToRead = wire.ComposeVersionNegotiation( + packetConn.dataToRead = wire.ComposeGQUICVersionNegotiation( cl.connectionID, []protocol.VersionNumber{newVersion}, ) @@ -395,17 +395,17 @@ var _ = Describe("Client", func() { newVersion := protocol.VersionNumber(77) Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) newVersion = protocol.VersionNumber(78) Expect(newVersion).ToNot(Equal(cl.version)) Expect(config.Versions).To(ContainElement(newVersion)) - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{newVersion})) Expect(atomic.LoadUint32(&sessionCounter)).To(BeEquivalentTo(2)) }) It("errors if no matching version is found", func() { - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{1})) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) @@ -414,13 +414,13 @@ var _ = Describe("Client", func() { v := protocol.SupportedVersions[1] Expect(v).ToNot(Equal(cl.version)) Expect(config.Versions).ToNot(ContainElement(v)) - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{v})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{v})) Expect(cl.session.(*mockSession).closed).To(BeTrue()) Expect(cl.session.(*mockSession).closeReason).To(MatchError(qerr.InvalidVersion)) }) It("changes to the version preferred by the quic.Config", func() { - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{config.Versions[2], config.Versions[1]})) Expect(cl.version).To(Equal(config.Versions[1])) }) @@ -428,14 +428,14 @@ var _ = Describe("Client", func() { // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test cl.versionNegotiated = true Expect(sess.packetCount).To(BeZero()) - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{1})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{1})) Expect(cl.versionNegotiated).To(BeTrue()) Expect(sess.packetCount).To(BeZero()) }) It("drops version negotiation packets that contain the offered version", func() { ver := cl.version - cl.handlePacket(nil, wire.ComposeVersionNegotiation(0x1337, []protocol.VersionNumber{ver})) + cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(0x1337, []protocol.VersionNumber{ver})) Expect(cl.version).To(Equal(ver)) }) }) diff --git a/internal/wire/public_header_test.go b/internal/wire/public_header_test.go index 68dd81e4..8c5ca801 100644 --- a/internal/wire/public_header_test.go +++ b/internal/wire/public_header_test.go @@ -102,7 +102,7 @@ var _ = Describe("Public Header", func() { } It("parses version negotiation packets sent by the server", func() { - b := bytes.NewReader(ComposeVersionNegotiation(0x1337, protocol.SupportedVersions)) + b := bytes.NewReader(ComposeGQUICVersionNegotiation(0x1337, protocol.SupportedVersions)) hdr, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) Expect(err).ToNot(HaveOccurred()) Expect(hdr.VersionFlag).To(BeTrue()) @@ -131,7 +131,7 @@ var _ = Describe("Public Header", func() { }) It("errors on invalid version tags", func() { - data := ComposeVersionNegotiation(0x1337, protocol.SupportedVersions) + data := ComposeGQUICVersionNegotiation(0x1337, protocol.SupportedVersions) data = append(data, []byte{0x13, 0x37}...) b := bytes.NewReader(data) _, err := parsePublicHeader(b, protocol.PerspectiveServer, protocol.VersionUnknown) diff --git a/internal/wire/version_negotiation.go b/internal/wire/version_negotiation.go index 6879aed0..9357a9ce 100644 --- a/internal/wire/version_negotiation.go +++ b/internal/wire/version_negotiation.go @@ -7,18 +7,36 @@ import ( "github.com/lucas-clemente/quic-go/internal/utils" ) -// ComposeVersionNegotiation composes a Version Negotiation Packet -// TODO(894): implement the IETF draft format of Version Negotiation Packets -func ComposeVersionNegotiation(connectionID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { +// ComposeGQUICVersionNegotiation composes a Version Negotiation Packet for gQUIC +func ComposeGQUICVersionNegotiation(connID protocol.ConnectionID, versions []protocol.VersionNumber) []byte { fullReply := &bytes.Buffer{} ph := Header{ - ConnectionID: connectionID, + ConnectionID: connID, PacketNumber: 1, VersionFlag: true, } - err := ph.writePublicHeader(fullReply, protocol.PerspectiveServer, protocol.VersionWhatever) - if err != nil { + if err := ph.writePublicHeader(fullReply, protocol.PerspectiveServer, protocol.VersionWhatever); err != nil { utils.Errorf("error composing version negotiation packet: %s", err.Error()) + return nil + } + for _, v := range versions { + utils.BigEndian.WriteUint32(fullReply, uint32(v)) + } + return fullReply.Bytes() +} + +// ComposeVersionNegotiation composes a Version Negotiation according to the IETF draft +func ComposeVersionNegotiation(connID protocol.ConnectionID, pn protocol.PacketNumber, versions []protocol.VersionNumber) []byte { + fullReply := &bytes.Buffer{} + ph := Header{ + IsLongHeader: true, + Type: protocol.PacketTypeVersionNegotiation, + ConnectionID: connID, + PacketNumber: pn, + } + if err := ph.writeHeader(fullReply); err != nil { + utils.Errorf("error composing version negotiation packet: %s", err.Error()) + return nil } for _, v := range versions { utils.BigEndian.WriteUint32(fullReply, uint32(v)) diff --git a/internal/wire/version_negotiation_test.go b/internal/wire/version_negotiation_test.go index 23ac6bfd..869d008b 100644 --- a/internal/wire/version_negotiation_test.go +++ b/internal/wire/version_negotiation_test.go @@ -1,17 +1,32 @@ package wire import ( + "bytes" + "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -var _ = Describe("Version Negotiation Packet", func() { - It("composes version negotiation packets", func() { - expected := append( - []byte{0x01 | 0x08, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, - []byte{'Q', '0', '3', '9'}..., - ) - Expect(ComposeVersionNegotiation(1, []protocol.VersionNumber{protocol.Version39})).To(Equal(expected)) +var _ = Describe("Version Negotiation Packets", func() { + It("writes for gQUIC", func() { + versions := []protocol.VersionNumber{1001, 1003} + data := ComposeGQUICVersionNegotiation(0x1337, versions) + hdr, err := parsePublicHeader(bytes.NewReader(data), protocol.PerspectiveServer, protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.VersionFlag).To(BeTrue()) + Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(hdr.SupportedVersions).To(Equal(versions)) + }) + + It("writes IETF draft style", func() { + versions := []protocol.VersionNumber{1001, 1003} + data := ComposeVersionNegotiation(0x1337, 0x42, versions) + hdr, err := parseHeader(bytes.NewReader(data), protocol.PerspectiveServer) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.Type).To(Equal(protocol.PacketTypeVersionNegotiation)) + Expect(hdr.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(hdr.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) + Expect(hdr.SupportedVersions).To(Equal(versions)) }) }) diff --git a/server.go b/server.go index f8ea5946..056dca72 100644 --- a/server.go +++ b/server.go @@ -266,15 +266,21 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet return nil } - // Send Version Negotiation Packet if the client is speaking a different protocol version + // send a Version Negotiation Packet if the client is speaking a different protocol version + // since the client send a Public Header (only gQUIC has a Version Flag), we need to send a gQUIC Version Negotiation Packet if hdr.VersionFlag && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { // drop packets that are too small to be valid first packets if len(packet) < protocol.ClientHelloMinimumSize+len(hdr.Raw) { return errors.New("dropping small packet with unknown version") } - // TODO(894): send a IETF draft style Version Negotiation Packets utils.Infof("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) - _, err = pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr) + if _, err := pconn.WriteTo(wire.ComposeGQUICVersionNegotiation(hdr.ConnectionID, s.config.Versions), remoteAddr); err != nil { + return err + } + } + // send an IETF draft style Version Negotiation Packet, if the client sent an unsupported version with an IETF draft style header + if hdr.Type == protocol.PacketTypeClientInitial && !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + _, err := pconn.WriteTo(wire.ComposeVersionNegotiation(hdr.ConnectionID, hdr.PacketNumber, s.config.Versions), remoteAddr) return err } diff --git a/server_test.go b/server_test.go index f771f715..350c83c9 100644 --- a/server_test.go +++ b/server_test.go @@ -413,22 +413,55 @@ var _ = Describe("Server", func() { ln, err := Listen(conn, nil, config) Expect(err).ToNot(HaveOccurred()) - var returned bool + done := make(chan struct{}) go func() { ln.Accept() - returned = true + close(done) }() Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) Expect(conn.dataWrittenTo).To(Equal(udpAddr)) - b = &bytes.Buffer{} - utils.BigEndian.WriteUint32(b, uint32(99)) - expected := append( - []byte{0x9, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x13, 0x37}, - b.Bytes()..., - ) - Expect(conn.dataWritten.Bytes()).To(Equal(expected)) - Consistently(func() bool { return returned }).Should(BeFalse()) + r := bytes.NewReader(conn.dataWritten.Bytes()) + packet, err := wire.ParseHeader(r, protocol.PerspectiveServer, protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.VersionFlag).To(BeTrue()) + Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(r.Len()).To(BeZero()) + Consistently(done).ShouldNot(BeClosed()) + }) + + It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() { + config.Versions = []protocol.VersionNumber{99} + b := &bytes.Buffer{} + hdr := wire.Header{ + Type: protocol.PacketTypeClientInitial, + IsLongHeader: true, + ConnectionID: 0x1337, + PacketNumber: 0x55, + } + hdr.Write(b, protocol.PerspectiveClient, protocol.VersionTLS) + b.Write(bytes.Repeat([]byte{0}, protocol.ClientHelloMinimumSize)) // add a fake CHLO + conn.dataToRead = b.Bytes() + conn.dataReadFrom = udpAddr + ln, err := Listen(conn, nil, config) + Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + go func() { + ln.Accept() + close(done) + }() + + Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) + Expect(conn.dataWrittenTo).To(Equal(udpAddr)) + r := bytes.NewReader(conn.dataWritten.Bytes()) + packet, err := wire.ParseHeader(r, protocol.PerspectiveServer, protocol.VersionUnknown) + Expect(err).ToNot(HaveOccurred()) + Expect(packet.Type).To(Equal(protocol.PacketTypeVersionNegotiation)) + Expect(packet.ConnectionID).To(Equal(protocol.ConnectionID(0x1337))) + Expect(packet.PacketNumber).To(Equal(protocol.PacketNumber(0x55))) + Expect(r.Len()).To(BeZero()) + Consistently(done).ShouldNot(BeClosed()) }) It("sends a PublicReset for new connections that don't have the VersionFlag set", func() {