diff --git a/internal/handshake/transport_parameter_test.go b/internal/handshake/transport_parameter_test.go index 5168b150..bb377624 100644 --- a/internal/handshake/transport_parameter_test.go +++ b/internal/handshake/transport_parameter_test.go @@ -21,6 +21,15 @@ var _ = Describe("Transport Parameters", func() { return append(data, tp...) } + getRandomValue := func() uint64 { + maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} + return uint64(rand.Int63n(maxVals[int(rand.Int31n(4))])) + } + + BeforeEach(func() { + rand.Seed(GinkgoRandomSeed()) + }) + It("has a string representation", func() { p := &TransportParameters{ InitialMaxStreamDataBidiLocal: 0x1234, @@ -57,12 +66,6 @@ var _ = Describe("Transport Parameters", func() { }) It("marshals and unmarshals", func() { - rand.Seed(GinkgoRandomSeed()) - getRandomValue := func() uint64 { - maxVals := []int64{math.MaxUint8 / 4, math.MaxUint16 / 4, math.MaxUint32 / 4, math.MaxUint64 / 4} - return uint64(rand.Int63n(maxVals[int(rand.Int31n(4))])) - } - var token [16]byte rand.Read(token[:]) params := &TransportParameters{ @@ -327,4 +330,78 @@ var _ = Describe("Transport Parameters", func() { } }) }) + + Context("saving and retrieving from a session ticket", func() { + It("saves and retrieves the parameters", func() { + params := &TransportParameters{ + InitialMaxStreamDataBidiLocal: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataBidiRemote: protocol.ByteCount(getRandomValue()), + InitialMaxStreamDataUni: protocol.ByteCount(getRandomValue()), + InitialMaxData: protocol.ByteCount(getRandomValue()), + MaxBidiStreamNum: protocol.StreamNum(getRandomValue()), + MaxUniStreamNum: protocol.StreamNum(getRandomValue()), + } + data := params.MarshalForSessionTicket() + Expect(params.ValidFromSessionTicket(data)).To(BeTrue()) + }) + + It("rejects the parameters if it can't parse them", func() { + params := &TransportParameters{} + Expect(params.ValidFromSessionTicket([]byte("foobar"))).To(BeFalse()) + }) + + Context("rejects the parameters if they changed", func() { + var p *TransportParameters + params := &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + InitialMaxStreamDataUni: 3, + InitialMaxData: 4, + MaxBidiStreamNum: 5, + MaxUniStreamNum: 6, + } + + BeforeEach(func() { + p = &TransportParameters{ + InitialMaxStreamDataBidiLocal: 1, + InitialMaxStreamDataBidiRemote: 2, + InitialMaxStreamDataUni: 3, + InitialMaxData: 4, + MaxBidiStreamNum: 5, + MaxUniStreamNum: 6, + } + Expect(params.ValidFromSessionTicket(p.Marshal())).To(BeTrue()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiLocal changed", func() { + p.InitialMaxStreamDataBidiLocal = 0 + Expect(params.ValidFromSessionTicket(p.Marshal())).To(BeFalse()) + }) + + It("rejects the parameters if the InitialMaxStreamDataBidiRemote changed", func() { + p.InitialMaxStreamDataBidiRemote = 0 + Expect(params.ValidFromSessionTicket(p.Marshal())).To(BeFalse()) + }) + + It("rejects the parameters if the InitialMaxStreamDataUni changed", func() { + p.InitialMaxStreamDataUni = 0 + Expect(params.ValidFromSessionTicket(p.Marshal())).To(BeFalse()) + }) + + It("rejects the parameters if the InitialMaxData changed", func() { + p.InitialMaxData = 0 + Expect(params.ValidFromSessionTicket(p.Marshal())).To(BeFalse()) + }) + + It("rejects the parameters if the MaxBidiStreamNum changed", func() { + p.MaxBidiStreamNum = 0 + Expect(params.ValidFromSessionTicket(p.Marshal())).To(BeFalse()) + }) + + It("rejects the parameters if the MaxUniStreamNum changed", func() { + p.MaxUniStreamNum = 0 + Expect(params.ValidFromSessionTicket(p.Marshal())).To(BeFalse()) + }) + }) + }) }) diff --git a/internal/handshake/transport_parameters.go b/internal/handshake/transport_parameters.go index 0d1337ca..47cea6ee 100644 --- a/internal/handshake/transport_parameters.go +++ b/internal/handshake/transport_parameters.go @@ -369,6 +369,51 @@ func (p *TransportParameters) marshalVarintParam(b *bytes.Buffer, id transportPa utils.WriteVarInt(b, val) } +// MarshalForSessionTicket marshals the transport parameters we save in the session ticket. +// When sending a 0-RTT enabled TLS session tickets, we need to save the transport parameters. +// The client will remember the transport parameters used in the last session, +// and apply those to the 0-RTT data it sends. +// Saving the transport parameters in the ticket gives the server the option to reject 0-RTT +// if the transport parameters changed. +// Since the session ticket is encrypted, the serialization format is defined by the server. +// For convenience, we use the same format that we also use for sending the transport parameters. +func (p *TransportParameters) MarshalForSessionTicket() []byte { + b := &bytes.Buffer{} + b.Write([]byte{0, 0}) // length. Will be replaced later + + // initial_max_stream_data_bidi_local + p.marshalVarintParam(b, initialMaxStreamDataBidiLocalParameterID, uint64(p.InitialMaxStreamDataBidiLocal)) + // initial_max_stream_data_bidi_remote + p.marshalVarintParam(b, initialMaxStreamDataBidiRemoteParameterID, uint64(p.InitialMaxStreamDataBidiRemote)) + // initial_max_stream_data_uni + p.marshalVarintParam(b, initialMaxStreamDataUniParameterID, uint64(p.InitialMaxStreamDataUni)) + // initial_max_data + p.marshalVarintParam(b, initialMaxDataParameterID, uint64(p.InitialMaxData)) + // initial_max_bidi_streams + p.marshalVarintParam(b, initialMaxStreamsBidiParameterID, uint64(p.MaxBidiStreamNum)) + // initial_max_uni_streams + p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum)) + + data := b.Bytes() + binary.BigEndian.PutUint16(data[:2], uint16(b.Len()-2)) + return data +} + +// ValidFromSessionTicket checks if the transport parameters match those saved in the session ticket. +func (p *TransportParameters) ValidFromSessionTicket(data []byte) bool { + tp := &TransportParameters{} + if err := tp.Unmarshal(data, protocol.PerspectiveServer); err != nil { + return false + } + + return p.InitialMaxStreamDataBidiLocal == tp.InitialMaxStreamDataBidiLocal && + p.InitialMaxStreamDataBidiRemote == tp.InitialMaxStreamDataBidiRemote && + p.InitialMaxStreamDataUni == tp.InitialMaxStreamDataUni && + p.InitialMaxData == tp.InitialMaxData && + p.MaxBidiStreamNum == tp.MaxBidiStreamNum && + p.MaxUniStreamNum == tp.MaxUniStreamNum +} + // String returns a string representation, intended for logging. func (p *TransportParameters) String() string { logString := "&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreamNum: %d, MaxUniStreamNum: %d, MaxIdleTimeout: %s, AckDelayExponent: %d, MaxAckDelay: %s, ActiveConnectionIDLimit: %d"