diff --git a/internal/handshake/tls_extension.go b/internal/handshake/tls_extension.go index 6f49dcf2..9cfb665a 100644 --- a/internal/handshake/tls_extension.go +++ b/internal/handshake/tls_extension.go @@ -38,7 +38,7 @@ func (p *clientHelloTransportParameters) Unmarshal(data []byte) error { if len(data) != paramsLen { return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data)) } - return p.Parameters.unmarshal(data) + return p.Parameters.unmarshal(data, protocol.PerspectiveClient) } type encryptedExtensionsTransportParameters struct { @@ -86,5 +86,5 @@ func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error { if len(data) != paramsLen { return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data)) } - return p.Parameters.unmarshal(data) + return p.Parameters.unmarshal(data, protocol.PerspectiveServer) } diff --git a/internal/handshake/tls_extension_handler_server.go b/internal/handshake/tls_extension_handler_server.go index 6755d899..152e10f9 100644 --- a/internal/handshake/tls_extension_handler_server.go +++ b/internal/handshake/tls_extension_handler_server.go @@ -80,12 +80,6 @@ func (h *extensionHandlerServer) ReceivedExtensions(msgType uint8, exts []qtls.E if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) { return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version") } - - // check that the client didn't send a stateless reset token - if len(chtp.Parameters.StatelessResetToken) != 0 { - // TODO: return the correct error type - return errors.New("client sent a stateless reset token") - } h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters) h.paramsChan <- chtp.Parameters return nil diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index 01e8b72f..e64c40c2 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -99,13 +98,6 @@ var _ = Describe("TLS Extension Handler, for the server", func() { Expect(err).To(HaveOccurred()) // this will be some kind of decoding error }) - It("rejects messages that contain a stateless reset token", func() { - parameters.StatelessResetToken = bytes.Repeat([]byte{0}, 16) - ext := getClientHello(parameters) - err := handler.ReceivedExtensions(uint8(typeClientHello), []qtls.Extension{ext}) - Expect(err).To(MatchError("client sent a stateless reset token")) - }) - Context("Version Negotiation", func() { It("accepts a ClientHello, when no version negotiation was performed", func() { done := make(chan struct{}) diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 9f0241ae..b6f71efc 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -48,7 +48,7 @@ var _ = Describe("Transport Parameters", func() { params.marshal(b) p := &TransportParameters{} - Expect(p.unmarshal(b.Bytes())).To(Succeed()) + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(Succeed()) Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(params.InitialMaxStreamDataBidiLocal)) Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(params.InitialMaxStreamDataBidiRemote)) Expect(p.InitialMaxStreamDataUni).To(Equal(params.InitialMaxStreamDataUni)) @@ -65,7 +65,7 @@ var _ = Describe("Transport Parameters", func() { b := &bytes.Buffer{} params.marshal(b) p := &TransportParameters{} - Expect(p.unmarshal(b.Bytes())).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)")) + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("wrong length for stateless_reset_token: 15 (expected 16)")) }) It("errors when the max_packet_size is too small", func() { @@ -74,7 +74,7 @@ var _ = Describe("Transport Parameters", func() { utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(1199))) utils.WriteVarInt(b, 1199) p := &TransportParameters{} - Expect(p.unmarshal(b.Bytes())).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)")) + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("invalid value for max_packet_size: 1199 (minimum 1200)")) }) It("errors when disable_migration has content", func() { @@ -83,7 +83,7 @@ var _ = Describe("Transport Parameters", func() { utils.BigEndian.WriteUint16(b, 6) b.Write([]byte("foobar")) p := &TransportParameters{} - Expect(p.unmarshal(b.Bytes())).To(MatchError("wrong length for disable_migration: 6 (expected empty)")) + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("wrong length for disable_migration: 6 (expected empty)")) }) It("errors when the varint value has the wrong length", func() { @@ -94,7 +94,7 @@ var _ = Describe("Transport Parameters", func() { Expect(utils.VarIntLen(val)).ToNot(BeEquivalentTo(2)) utils.WriteVarInt(b, val) p := &TransportParameters{} - err := p.unmarshal(b.Bytes()) + err := p.unmarshal(b.Bytes(), protocol.PerspectiveServer) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("inconsistent transport parameter length")) }) @@ -114,7 +114,7 @@ var _ = Describe("Transport Parameters", func() { utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(0x42))) utils.WriteVarInt(b, 0x42) p := &TransportParameters{} - Expect(p.unmarshal(b.Bytes())).To(Succeed()) + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(Succeed()) Expect(p.InitialMaxStreamDataBidiLocal).To(Equal(protocol.ByteCount(0x1337))) Expect(p.InitialMaxStreamDataBidiRemote).To(Equal(protocol.ByteCount(0x42))) }) @@ -134,7 +134,7 @@ var _ = Describe("Transport Parameters", func() { utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(0x1337))) utils.WriteVarInt(b, 0x1337) p := &TransportParameters{} - err := p.unmarshal(b.Bytes()) + err := p.unmarshal(b.Bytes(), protocol.PerspectiveServer) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("received duplicate transport parameter")) }) @@ -145,7 +145,7 @@ var _ = Describe("Transport Parameters", func() { utils.BigEndian.WriteUint16(b, 7) b.Write([]byte("foobar")) p := &TransportParameters{} - Expect(p.unmarshal(b.Bytes())).To(MatchError("remaining length (6) smaller than parameter length (7)")) + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("remaining length (6) smaller than parameter length (7)")) }) It("errors if there's unprocessed data after reading", func() { @@ -155,6 +155,16 @@ var _ = Describe("Transport Parameters", func() { utils.WriteVarInt(b, 0x1337) b.Write([]byte("foo")) p := &TransportParameters{} - Expect(p.unmarshal(b.Bytes())).To(MatchError("should have read all data. Still have 3 bytes")) + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveServer)).To(MatchError("should have read all data. Still have 3 bytes")) + }) + + It("errors if the client sent a stateless_reset_token", func() { + params := &TransportParameters{ + StatelessResetToken: make([]byte, 16), + } + b := &bytes.Buffer{} + params.marshal(b) + p := &TransportParameters{} + Expect(p.unmarshal(b.Bytes(), protocol.PerspectiveClient)).To(MatchError("client sent a stateless_reset_token")) }) }) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index ba9e68fd..d1936cd7 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -2,6 +2,7 @@ package handshake import ( "bytes" + "errors" "fmt" "io" "sort" @@ -43,7 +44,7 @@ type TransportParameters struct { StatelessResetToken []byte } -func (p *TransportParameters) unmarshal(data []byte) error { +func (p *TransportParameters) unmarshal(data []byte, sentBy protocol.Perspective) error { // needed to check that every parameter is only sent at most once var parameterIDs []transportParameterID @@ -76,6 +77,9 @@ func (p *TransportParameters) unmarshal(data []byte) error { } p.DisableMigration = true case statelessResetTokenParameterID: + if sentBy == protocol.PerspectiveClient { + return errors.New("client sent a stateless_reset_token") + } if paramLen != 16 { return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen) }