diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 930e1b6d8..08fcc049f 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -53,13 +53,13 @@ var _ = Describe("Transport Parameters", func() { Expect(p.String()).To(Equal("&handshake.TransportParameters{OriginalConnectionID: 0xdeadbeef, InitialMaxStreamDataBidiLocal: 0x1234, InitialMaxStreamDataBidiRemote: 0x2345, InitialMaxStreamDataUni: 0x3456, InitialMaxData: 0x4567, MaxBidiStreams: 1337, MaxUniStreams: 7331, IdleTimeout: 42s, AckDelayExponent: 14, MaxAckDelay: 37s}")) }) - getRandomValue := func() uint64 { - maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} - rand.Seed(GinkgoRandomSeed()) - return uint64(rand.Int63n(maxVals[int(rand.Int31n(4))])) - } - It("marshals and unmarshals", func() { + rand.Seed(GinkgoRandomSeed()) + getRandomValue := func() uint64 { + maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} + return uint64(rand.Int63n(maxVals[int(rand.Int31n(4))])) + } + var token [16]byte rand.Read(token[:]) params := &TransportParameters{ diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index 70f748d2b..e8f97624a 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -205,51 +205,30 @@ func (p *TransportParameters) Marshal() []byte { b.Write([]byte{0, 0}) // length. Will be replaced later // initial_max_stream_data_bidi_local - utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiLocalParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiLocal)))) - utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiLocal)) + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) // initial_max_stream_data_bidi_remote - utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiRemoteParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiRemote)))) - utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiRemote)) + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) // initial_max_stream_data_uni - utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataUniParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataUni)))) - utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataUni)) + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) // initial_max_data - utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxData)))) - utils.WriteVarInt(b, uint64(p.InitialMaxData)) + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) // initial_max_bidi_streams - utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsBidiParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxBidiStreams))) - utils.WriteVarInt(b, p.MaxBidiStreams) + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, p.MaxBidiStreams) // initial_max_uni_streams - utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsUniParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxUniStreams))) - utils.WriteVarInt(b, p.MaxUniStreams) + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, p.MaxUniStreams) // idle_timeout - idleTimeout := uint64(p.IdleTimeout / time.Millisecond) - utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(idleTimeout))) - utils.WriteVarInt(b, idleTimeout) + p.marshalVarintParam(b, idleTimeoutParameterID, uint64(p.IdleTimeout/time.Millisecond)) // max_packet_size - utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(protocol.MaxReceivePacketSize)))) - utils.WriteVarInt(b, uint64(protocol.MaxReceivePacketSize)) + p.marshalVarintParam(b, maxPacketSizeParameterID, uint64(protocol.MaxReceivePacketSize)) // max_ack_delay // Only send it if is different from the default value. if p.MaxAckDelay != protocol.DefaultMaxAckDelay { - utils.BigEndian.WriteUint16(b, uint16(maxAckDelayParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.MaxAckDelay/time.Millisecond)))) - utils.WriteVarInt(b, uint64(p.MaxAckDelay/time.Millisecond)) + p.marshalVarintParam(b, maxAckDelayParameterID, uint64(p.MaxAckDelay/time.Millisecond)) } // ack_delay_exponent // Only send it if is different from the default value. if p.AckDelayExponent != protocol.DefaultAckDelayExponent { - utils.BigEndian.WriteUint16(b, uint16(ackDelayExponentParameterID)) - utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.AckDelayExponent)))) - utils.WriteVarInt(b, uint64(p.AckDelayExponent)) + p.marshalVarintParam(b, ackDelayExponentParameterID, uint64(p.AckDelayExponent)) } // disable_migration if p.DisableMigration { @@ -273,6 +252,12 @@ func (p *TransportParameters) Marshal() []byte { return data } +func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportParameterID, val uint64) { + utils.BigEndian.WriteUint16(b, uint16(id)) + utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(val))) + utils.WriteVarInt(b, val) +} + // String returns a string representation, intended for logging. func (p *TransportParameters) String() string { logString := "&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s"