diff --git a/handshake/connection_parameters_manager.go b/handshake/connection_parameters_manager.go index c36890f56..f654b06bc 100644 --- a/handshake/connection_parameters_manager.go +++ b/handshake/connection_parameters_manager.go @@ -18,6 +18,7 @@ type ConnectionParametersManager struct { params map[Tag][]byte mutex sync.RWMutex + maxStreamsPerConnection uint32 idleConnectionStateLifetime time.Duration sendStreamFlowControlWindow protocol.ByteCount sendConnectionFlowControlWindow protocol.ByteCount @@ -31,9 +32,7 @@ var ErrTagNotInConnectionParameterMap = errors.New("Tag not found in Connections // NewConnectionParamatersManager creates a new connection parameters manager func NewConnectionParamatersManager() *ConnectionParametersManager { return &ConnectionParametersManager{ - params: map[Tag][]byte{ - TagMSPC: {0x64, 0x00, 0x00, 0x00}, // Max streams per connection = 100 - }, + params: make(map[Tag][]byte), idleConnectionStateLifetime: protocol.InitialIdleConnectionStateLifetime, sendStreamFlowControlWindow: protocol.InitialStreamFlowControlWindow, // can only be changed by the client sendConnectionFlowControlWindow: protocol.InitialConnectionFlowControlWindow, // can only be changed by the client @@ -49,8 +48,14 @@ func (h *ConnectionParametersManager) SetFromMap(params map[Tag][]byte) error { for key, value := range params { switch key { - case TagMSPC, TagTCID: + case TagTCID: h.params[key] = value + case TagMSPC: + clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) + if err != nil { + return err + } + h.maxStreamsPerConnection = h.negotiateMaxStreamsPerConnection(clientValue) case TagICSL: clientValue, err := utils.ReadUint32(bytes.NewBuffer(value)) if err != nil { @@ -75,6 +80,10 @@ func (h *ConnectionParametersManager) SetFromMap(params map[Tag][]byte) error { return nil } +func (h *ConnectionParametersManager) negotiateMaxStreamsPerConnection(clientValue uint32) uint32 { + return utils.MinUint32(clientValue, protocol.MaxStreamsPerConnection) +} + func (h *ConnectionParametersManager) negotiateIdleConnectionStateLifetime(clientValue time.Duration) time.Duration { // TODO: what happens if the clients sets 0 seconds? return utils.MinDuration(clientValue, protocol.MaxIdleConnectionStateLifetime) @@ -98,13 +107,14 @@ func (h *ConnectionParametersManager) GetSHLOMap() map[Tag][]byte { utils.WriteUint32(sfcw, uint32(h.GetReceiveStreamFlowControlWindow())) cfcw := bytes.NewBuffer([]byte{}) utils.WriteUint32(cfcw, uint32(h.GetReceiveConnectionFlowControlWindow())) + mspc := bytes.NewBuffer([]byte{}) + utils.WriteUint32(mspc, uint32(h.GetMaxStreamsPerConnection())) icsl := bytes.NewBuffer([]byte{}) - utils.Debugf("ICSL: %#v\n", h.GetIdleConnectionStateLifetime()) utils.WriteUint32(icsl, uint32(h.GetIdleConnectionStateLifetime()/time.Second)) return map[Tag][]byte{ TagICSL: icsl.Bytes(), - TagMSPC: []byte{0x64, 0x00, 0x00, 0x00}, //100 + TagMSPC: mspc.Bytes(), TagCFCW: cfcw.Bytes(), TagSFCW: sfcw.Bytes(), } @@ -142,6 +152,14 @@ func (h *ConnectionParametersManager) GetReceiveConnectionFlowControlWindow() pr return h.receiveConnectionFlowControlWindow } +// GetMaxStreamsPerConnection gets the maximum number of streams per connection +func (h *ConnectionParametersManager) GetMaxStreamsPerConnection() uint32 { + h.mutex.RLock() + defer h.mutex.RUnlock() + + return h.maxStreamsPerConnection +} + // GetIdleConnectionStateLifetime gets the idle timeout func (h *ConnectionParametersManager) GetIdleConnectionStateLifetime() time.Duration { h.mutex.RLock() diff --git a/handshake/connection_parameters_manager_test.go b/handshake/connection_parameters_manager_test.go index 81891b78a..affc91b8c 100644 --- a/handshake/connection_parameters_manager_test.go +++ b/handshake/connection_parameters_manager_test.go @@ -15,16 +15,16 @@ var _ = Describe("ConnectionsParameterManager", func() { }) It("stores and retrieves a value", func() { - mspc := []byte{0x13, 0x37} + tcid := []byte{0x13, 0x37} values := map[Tag][]byte{ - TagMSPC: mspc, + TagTCID: tcid, } cpm.SetFromMap(values) - val, err := cpm.getRawValue(TagMSPC) + val, err := cpm.getRawValue(TagTCID) Expect(err).ToNot(HaveOccurred()) - Expect(val).To(Equal(mspc)) + Expect(val).To(Equal(tcid)) }) It("returns an error for a tag that is not set", func() { @@ -40,26 +40,33 @@ var _ = Describe("ConnectionsParameterManager", func() { Expect(entryMap).To(HaveKey(TagMSPC)) }) - It("returns stream-level flow control windows in SHLO", func() { + It("sets the stream-level flow control windows in SHLO", func() { cpm.receiveStreamFlowControlWindow = 0xDEADBEEF entryMap := cpm.GetSHLOMap() Expect(entryMap).To(HaveKey(TagSFCW)) Expect(entryMap[TagSFCW]).To(Equal([]byte{0xEF, 0xBE, 0xAD, 0xDE})) }) - It("returns connection-level flow control windows in SHLO", func() { + It("sets the connection-level flow control windows in SHLO", func() { cpm.receiveConnectionFlowControlWindow = 0xDECAFBAD entryMap := cpm.GetSHLOMap() Expect(entryMap).To(HaveKey(TagCFCW)) Expect(entryMap[TagCFCW]).To(Equal([]byte{0xAD, 0xFB, 0xCA, 0xDE})) }) - It("returns connection-level flow control windows in SHLO", func() { + It("sets the connection-level flow control windows in SHLO", func() { cpm.idleConnectionStateLifetime = 0xDECAFBAD * time.Second entryMap := cpm.GetSHLOMap() Expect(entryMap).To(HaveKey(TagICSL)) Expect(entryMap[TagICSL]).To(Equal([]byte{0xAD, 0xFB, 0xCA, 0xDE})) }) + + It("sets the maximum streams per connection in SHLO", func() { + cpm.maxStreamsPerConnection = 0xDEADBEEF + entryMap := cpm.GetSHLOMap() + Expect(entryMap).To(HaveKey(TagMSPC)) + Expect(entryMap[TagMSPC]).To(Equal([]byte{0xEF, 0xBE, 0xAD, 0xDE})) + }) }) Context("Truncated connection IDs", func() { @@ -159,4 +166,30 @@ var _ = Describe("ConnectionsParameterManager", func() { Expect(cpm.GetIdleConnectionStateLifetime()).To(Equal(value)) }) }) + + Context("max streams per connection", func() { + It("negotiates correctly when the client wants a larger number", func() { + Expect(cpm.negotiateMaxStreamsPerConnection(protocol.MaxStreamsPerConnection + 10)).To(Equal(protocol.MaxStreamsPerConnection)) + }) + + It("negotiates correctly when the client wants a smaller number", func() { + Expect(cpm.negotiateMaxStreamsPerConnection(protocol.MaxStreamsPerConnection - 1)).To(Equal(protocol.MaxStreamsPerConnection - 1)) + }) + + It("sets the negotiated max streams per connection value", func() { + // this test only works if the value given here is smaller than protocol.MaxStreamsPerConnection + values := map[Tag][]byte{ + TagMSPC: []byte{2, 0, 0, 0}, + } + err := cpm.SetFromMap(values) + Expect(err).ToNot(HaveOccurred()) + Expect(cpm.GetMaxStreamsPerConnection()).To(Equal(uint32(2))) + }) + + It("gets the max streams per connection value", func() { + var value uint32 = 0xDECAFBAD + cpm.maxStreamsPerConnection = value + Expect(cpm.GetMaxStreamsPerConnection()).To(Equal(value)) + }) + }) }) diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index 752c334fe..6ed22cda7 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -32,6 +32,10 @@ const ReceiveStreamFlowControlWindow ByteCount = (1 << 20) // 1 MB // TODO: set a reasonable value here const ReceiveConnectionFlowControlWindow ByteCount = (1 << 20) // 1 MB -// MaxIdleConnectionStateLifetime is the maximum value we accept for the idle connection state lifetime +// MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection +// TODO: set a reasonable value here +const MaxStreamsPerConnection uint32 = 100 + +// MaxIdleConnectionStateLifetime is the maximum value accepted for the idle connection state lifetime // TODO: set a reasonable value here const MaxIdleConnectionStateLifetime = 60 * time.Second diff --git a/utils/minmax.go b/utils/minmax.go index 1fd514cb2..370cb1929 100644 --- a/utils/minmax.go +++ b/utils/minmax.go @@ -34,6 +34,14 @@ func Min(a, b int) int { return b } +// MinUint32 returns the maximum of two uint32 +func MinUint32(a, b uint32) uint32 { + if a < b { + return a + } + return b +} + // MinInt64 returns the minimum of two int64 func MinInt64(a, b int64) int64 { if a < b { @@ -58,6 +66,7 @@ func MaxDuration(a, b time.Duration) time.Duration { return b } +// MinDuration returns the minimum duration func MinDuration(a, b time.Duration) time.Duration { if a > b { return b diff --git a/utils/minmax_test.go b/utils/minmax_test.go index e09a84934..bbd3225cd 100644 --- a/utils/minmax_test.go +++ b/utils/minmax_test.go @@ -46,6 +46,11 @@ var _ = Describe("Min / Max", func() { Expect(Min(7, 5)).To(Equal(5)) }) + It("returns the minimum uint32", func() { + Expect(MinUint32(7, 5)).To(Equal(uint32(5))) + Expect(MinUint32(5, 7)).To(Equal(uint32(5))) + }) + It("returns the minimum int64", func() { Expect(MinInt64(7, 5)).To(Equal(int64(5))) Expect(MinInt64(5, 7)).To(Equal(int64(5)))