From d07baef91b3f8a43ac71089b62f841c8684b9097 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 4 May 2016 17:20:36 +0700 Subject: [PATCH] manage connection parameters in a separate class --- handshake/connection_parameters_manager.go | 53 +++++++++++++++++++ .../connection_parameters_manager_test.go | 45 ++++++++++++++++ handshake/crypto_setup.go | 37 ++++++++----- handshake/crypto_setup_test.go | 4 +- session.go | 25 +++++---- 5 files changed, 138 insertions(+), 26 deletions(-) create mode 100644 handshake/connection_parameters_manager.go create mode 100644 handshake/connection_parameters_manager_test.go diff --git a/handshake/connection_parameters_manager.go b/handshake/connection_parameters_manager.go new file mode 100644 index 00000000..5161c74f --- /dev/null +++ b/handshake/connection_parameters_manager.go @@ -0,0 +1,53 @@ +package handshake + +import ( + "errors" + "sync" +) + +// ConnectionParametersManager stores the connection parameters +type ConnectionParametersManager struct { + params map[Tag][]byte + mutex sync.RWMutex +} + +// ErrTagNotInConnectionParameterMap is returned when a tag is not present in the connection parameters +var ErrTagNotInConnectionParameterMap = errors.New("Tag not found in ConnectionsParameter map") + +// NewConnectionParamatersManager creates a new connection parameters manager +func NewConnectionParamatersManager() *ConnectionParametersManager { + cpm := &ConnectionParametersManager{ + params: make(map[Tag][]byte), + } + return cpm +} + +// SetFromMap reads all params +func (h *ConnectionParametersManager) SetFromMap(params map[Tag][]byte) error { + h.mutex.Lock() + for key, value := range params { + h.params[key] = value + } + h.mutex.Unlock() + return nil +} + +// GetRawValue gets the byte-slice for a tag +func (h *ConnectionParametersManager) GetRawValue(tag Tag) ([]byte, error) { + h.mutex.RLock() + rawValue, ok := h.params[tag] + h.mutex.RUnlock() + + if !ok { + return nil, ErrTagNotInConnectionParameterMap + } + return rawValue, nil +} + +// GetSHLOMap gets all values (except crypto values) needed for the SHLO +func (h *ConnectionParametersManager) GetSHLOMap() map[Tag][]byte { + return map[Tag][]byte{ + TagICSL: []byte{0x1e, 0x00, 0x00, 0x00}, //30 + TagMSPC: []byte{0x64, 0x00, 0x00, 0x00}, //100 + } +} diff --git a/handshake/connection_parameters_manager_test.go b/handshake/connection_parameters_manager_test.go new file mode 100644 index 00000000..7d71196a --- /dev/null +++ b/handshake/connection_parameters_manager_test.go @@ -0,0 +1,45 @@ +package handshake + +import ( + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("ConnectionsParameterManager", func() { + var cpm *ConnectionParametersManager + BeforeEach(func() { + cpm = NewConnectionParamatersManager() + }) + + It("stores and retrieves a value", func() { + kexs := []byte{0xDE, 0xCA, 0xFB, 0xAD} + icsl := []byte{0x13, 0x37} + values := map[Tag][]byte{ + TagKEXS: kexs, + TagICSL: icsl, + } + + cpm.SetFromMap(values) + + val, err := cpm.GetRawValue(TagKEXS) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(kexs)) + + val, err = cpm.GetRawValue(TagICSL) + Expect(err).ToNot(HaveOccurred()) + Expect(val).To(Equal(icsl)) + }) + + It("returns an error for a tag that is not set", func() { + _, err := cpm.GetRawValue(TagKEXS) + Expect(err).To(HaveOccurred()) + Expect(err).To(Equal(ErrTagNotInConnectionParameterMap)) + }) + + It("returns all parameters necessary for the SHLO", func() { + entryMap := cpm.GetSHLOMap() + Expect(entryMap).To(HaveKey(TagICSL)) + Expect(entryMap).To(HaveKey(TagMSPC)) + }) + +}) diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index b84bcb35..3b3fe6c7 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -31,24 +31,27 @@ type CryptoSetup struct { cryptoStream utils.Stream + connectionParametersManager *ConnectionParametersManager + mutex sync.RWMutex } var _ crypto.AEAD = &CryptoSetup{} // NewCryptoSetup creates a new CryptoSetup instance -func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream) *CryptoSetup { +func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream, connectionParametersManager *ConnectionParametersManager) *CryptoSetup { nonce := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { panic(err) } return &CryptoSetup{ - connID: connID, - version: version, - scfg: scfg, - nonce: nonce, - keyDerivation: crypto.DeriveKeysChacha20, - cryptoStream: cryptoStream, + connID: connID, + version: version, + scfg: scfg, + nonce: nonce, + keyDerivation: crypto.DeriveKeysChacha20, + cryptoStream: cryptoStream, + connectionParametersManager: connectionParametersManager, } } @@ -184,13 +187,19 @@ func (h *CryptoSetup) handleCHLO(data []byte, cryptoData map[Tag][]byte) ([]byte return nil, err } + err = h.connectionParametersManager.SetFromMap(cryptoData) + if err != nil { + return nil, err + } + + replyMap := h.connectionParametersManager.GetSHLOMap() + // add crypto parameters + replyMap[TagPUBS] = h.scfg.kex.PublicKey() + replyMap[TagSNO] = h.nonce + replyMap[TagVER] = protocol.SupportedVersionsAsTags + var reply bytes.Buffer - WriteHandshakeMessage(&reply, TagSHLO, map[Tag][]byte{ - TagPUBS: h.scfg.kex.PublicKey(), - TagSNO: h.nonce, - TagVER: protocol.SupportedVersionsAsTags, - TagICSL: []byte{0x1e, 0x00, 0x00, 0x00}, //30 - TagMSPC: []byte{0x64, 0x00, 0x00, 0x00}, //100 - }) + WriteHandshakeMessage(&reply, TagSHLO, replyMap) + return reply.Bytes(), nil } diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 3ac9da9d..35d3fc57 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -89,6 +89,7 @@ var _ = Describe("Crypto setup", func() { scfg *ServerConfig cs *CryptoSetup stream *mockStream + cpm *ConnectionParametersManager ) BeforeEach(func() { @@ -97,7 +98,8 @@ var _ = Describe("Crypto setup", func() { signer = &mockSigner{} scfg = NewServerConfig(kex, signer) v := protocol.SupportedVersions[len(protocol.SupportedVersions)-1] - cs = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream) + cpm = NewConnectionParamatersManager() + cs = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream, cpm) cs.keyDerivation = mockKeyDerivation }) diff --git a/session.go b/session.go index a4b0dea5..3fbcdc2a 100644 --- a/session.go +++ b/session.go @@ -52,6 +52,8 @@ type Session struct { closeChan chan struct{} closed bool + connectionParametersManager *handshake.ConnectionParametersManager + // Used to calculate the next packet number from the truncated wire // representation, and sent back in public reset packets lastRcvdPacketNumber protocol.PacketNumber @@ -63,20 +65,21 @@ type Session struct { func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback) PacketHandler { stopWaitingManager := ackhandler.NewStopWaitingManager() session := &Session{ - connectionID: connectionID, - conn: conn, - streamCallback: streamCallback, - streams: make(map[protocol.StreamID]*stream), - sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager), - receivedPacketHandler: ackhandler.NewReceivedPacketHandler(), - stopWaitingManager: stopWaitingManager, - receivedPackets: make(chan receivedPacket, 1000), // TODO: What if server receives many packets and connection is already closed?! - closeChan: make(chan struct{}, 1), - rttStats: congestion.RTTStats{}, + connectionID: connectionID, + conn: conn, + streamCallback: streamCallback, + streams: make(map[protocol.StreamID]*stream), + sentPacketHandler: ackhandler.NewSentPacketHandler(stopWaitingManager), + receivedPacketHandler: ackhandler.NewReceivedPacketHandler(), + stopWaitingManager: stopWaitingManager, + receivedPackets: make(chan receivedPacket, 1000), // TODO: What if server receives many packets and connection is already closed?! + closeChan: make(chan struct{}, 1), + rttStats: congestion.RTTStats{}, + connectionParametersManager: handshake.NewConnectionParamatersManager(), } cryptoStream, _ := session.NewStream(1) - cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream) + cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager) go func() { if err := cryptoSetup.HandleCryptoStream(); err != nil {