diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 62b682e6..4697fcbe 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -2,6 +2,7 @@ package handshake import ( "bytes" + "fmt" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -146,6 +147,7 @@ var _ = Describe("Transport Parameters", func() { statelessResetTokenParameterID: statelessResetToken, } }) + It("reads parameters", func() { err := params.unmarshal(marshal(parameters)) Expect(err).ToNot(HaveOccurred()) @@ -160,6 +162,16 @@ var _ = Describe("Transport Parameters", func() { Expect(params.StatelessResetToken).To(Equal(statelessResetToken)) }) + It("errors if a parameter is sent twice", func() { + data := marshal(parameters) + parameters = map[transportParameterID][]byte{ + maxPacketSizeParameterID: {0x73, 0x31}, + } + data = append(data, marshal(parameters)...) + err := params.unmarshal(data) + Expect(err).To(MatchError(fmt.Sprintf("received duplicate transport parameter %#x", maxPacketSizeParameterID))) + }) + It("doesn't allow values below the minimum remote idle timeout", func() { t := 2 * time.Second Expect(t).To(BeNumerically("<", protocol.MinRemoteIdleTimeout)) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index 273a203b..c94e31f7 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "sort" "time" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -96,14 +97,18 @@ func (p *TransportParameters) getHelloMap() map[Tag][]byte { } func (p *TransportParameters) unmarshal(data []byte) error { + // needed to check that every parameter is only sent at most once + var parameterIDs []transportParameterID + for len(data) >= 4 { - paramID := binary.BigEndian.Uint16(data[:2]) + paramID := transportParameterID(binary.BigEndian.Uint16(data[:2])) paramLen := int(binary.BigEndian.Uint16(data[2:4])) data = data[4:] if len(data) < paramLen { return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen) } - switch transportParameterID(paramID) { + parameterIDs = append(parameterIDs, paramID) + switch paramID { case initialMaxStreamDataParameterID: if paramLen != 4 { return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen) @@ -152,6 +157,14 @@ func (p *TransportParameters) unmarshal(data []byte) error { data = data[paramLen:] } + // check that every transport parameter was sent at most once + sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] }) + for i := 0; i < len(parameterIDs)-1; i++ { + if parameterIDs[i] == parameterIDs[i+1] { + return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i]) + } + } + if len(data) != 0 { return fmt.Errorf("should have read all data. Still have %d bytes", len(data)) }