forked from quic-go/quic-go
manage connection parameters in a separate class
This commit is contained in:
53
handshake/connection_parameters_manager.go
Normal file
53
handshake/connection_parameters_manager.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ConnectionParametersManager stores the connection parameters
|
||||
type ConnectionParametersManager struct {
|
||||
params map[Tag][]byte
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// ErrTagNotInConnectionParameterMap is returned when a tag is not present in the connection parameters
|
||||
var ErrTagNotInConnectionParameterMap = errors.New("Tag not found in ConnectionsParameter map")
|
||||
|
||||
// NewConnectionParamatersManager creates a new connection parameters manager
|
||||
func NewConnectionParamatersManager() *ConnectionParametersManager {
|
||||
cpm := &ConnectionParametersManager{
|
||||
params: make(map[Tag][]byte),
|
||||
}
|
||||
return cpm
|
||||
}
|
||||
|
||||
// SetFromMap reads all params
|
||||
func (h *ConnectionParametersManager) SetFromMap(params map[Tag][]byte) error {
|
||||
h.mutex.Lock()
|
||||
for key, value := range params {
|
||||
h.params[key] = value
|
||||
}
|
||||
h.mutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRawValue gets the byte-slice for a tag
|
||||
func (h *ConnectionParametersManager) GetRawValue(tag Tag) ([]byte, error) {
|
||||
h.mutex.RLock()
|
||||
rawValue, ok := h.params[tag]
|
||||
h.mutex.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, ErrTagNotInConnectionParameterMap
|
||||
}
|
||||
return rawValue, nil
|
||||
}
|
||||
|
||||
// GetSHLOMap gets all values (except crypto values) needed for the SHLO
|
||||
func (h *ConnectionParametersManager) GetSHLOMap() map[Tag][]byte {
|
||||
return map[Tag][]byte{
|
||||
TagICSL: []byte{0x1e, 0x00, 0x00, 0x00}, //30
|
||||
TagMSPC: []byte{0x64, 0x00, 0x00, 0x00}, //100
|
||||
}
|
||||
}
|
||||
45
handshake/connection_parameters_manager_test.go
Normal file
45
handshake/connection_parameters_manager_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ConnectionsParameterManager", func() {
|
||||
var cpm *ConnectionParametersManager
|
||||
BeforeEach(func() {
|
||||
cpm = NewConnectionParamatersManager()
|
||||
})
|
||||
|
||||
It("stores and retrieves a value", func() {
|
||||
kexs := []byte{0xDE, 0xCA, 0xFB, 0xAD}
|
||||
icsl := []byte{0x13, 0x37}
|
||||
values := map[Tag][]byte{
|
||||
TagKEXS: kexs,
|
||||
TagICSL: icsl,
|
||||
}
|
||||
|
||||
cpm.SetFromMap(values)
|
||||
|
||||
val, err := cpm.GetRawValue(TagKEXS)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(val).To(Equal(kexs))
|
||||
|
||||
val, err = cpm.GetRawValue(TagICSL)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(val).To(Equal(icsl))
|
||||
})
|
||||
|
||||
It("returns an error for a tag that is not set", func() {
|
||||
_, err := cpm.GetRawValue(TagKEXS)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err).To(Equal(ErrTagNotInConnectionParameterMap))
|
||||
})
|
||||
|
||||
It("returns all parameters necessary for the SHLO", func() {
|
||||
entryMap := cpm.GetSHLOMap()
|
||||
Expect(entryMap).To(HaveKey(TagICSL))
|
||||
Expect(entryMap).To(HaveKey(TagMSPC))
|
||||
})
|
||||
|
||||
})
|
||||
@@ -31,13 +31,15 @@ type CryptoSetup struct {
|
||||
|
||||
cryptoStream utils.Stream
|
||||
|
||||
connectionParametersManager *ConnectionParametersManager
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var _ crypto.AEAD = &CryptoSetup{}
|
||||
|
||||
// NewCryptoSetup creates a new CryptoSetup instance
|
||||
func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream) *CryptoSetup {
|
||||
func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream, connectionParametersManager *ConnectionParametersManager) *CryptoSetup {
|
||||
nonce := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
panic(err)
|
||||
@@ -49,6 +51,7 @@ func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber
|
||||
nonce: nonce,
|
||||
keyDerivation: crypto.DeriveKeysChacha20,
|
||||
cryptoStream: cryptoStream,
|
||||
connectionParametersManager: connectionParametersManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,13 +187,19 @@ func (h *CryptoSetup) handleCHLO(data []byte, cryptoData map[Tag][]byte) ([]byte
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = h.connectionParametersManager.SetFromMap(cryptoData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
replyMap := h.connectionParametersManager.GetSHLOMap()
|
||||
// add crypto parameters
|
||||
replyMap[TagPUBS] = h.scfg.kex.PublicKey()
|
||||
replyMap[TagSNO] = h.nonce
|
||||
replyMap[TagVER] = protocol.SupportedVersionsAsTags
|
||||
|
||||
var reply bytes.Buffer
|
||||
WriteHandshakeMessage(&reply, TagSHLO, map[Tag][]byte{
|
||||
TagPUBS: h.scfg.kex.PublicKey(),
|
||||
TagSNO: h.nonce,
|
||||
TagVER: protocol.SupportedVersionsAsTags,
|
||||
TagICSL: []byte{0x1e, 0x00, 0x00, 0x00}, //30
|
||||
TagMSPC: []byte{0x64, 0x00, 0x00, 0x00}, //100
|
||||
})
|
||||
WriteHandshakeMessage(&reply, TagSHLO, replyMap)
|
||||
|
||||
return reply.Bytes(), nil
|
||||
}
|
||||
|
||||
@@ -89,6 +89,7 @@ var _ = Describe("Crypto setup", func() {
|
||||
scfg *ServerConfig
|
||||
cs *CryptoSetup
|
||||
stream *mockStream
|
||||
cpm *ConnectionParametersManager
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
@@ -97,7 +98,8 @@ var _ = Describe("Crypto setup", func() {
|
||||
signer = &mockSigner{}
|
||||
scfg = NewServerConfig(kex, signer)
|
||||
v := protocol.SupportedVersions[len(protocol.SupportedVersions)-1]
|
||||
cs = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream)
|
||||
cpm = NewConnectionParamatersManager()
|
||||
cs = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream, cpm)
|
||||
cs.keyDerivation = mockKeyDerivation
|
||||
})
|
||||
|
||||
|
||||
@@ -52,6 +52,8 @@ type Session struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
|
||||
connectionParametersManager *handshake.ConnectionParametersManager
|
||||
|
||||
// Used to calculate the next packet number from the truncated wire
|
||||
// representation, and sent back in public reset packets
|
||||
lastRcvdPacketNumber protocol.PacketNumber
|
||||
@@ -73,10 +75,11 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
receivedPackets: make(chan receivedPacket, 1000), // TODO: What if server receives many packets and connection is already closed?!
|
||||
closeChan: make(chan struct{}, 1),
|
||||
rttStats: congestion.RTTStats{},
|
||||
connectionParametersManager: handshake.NewConnectionParamatersManager(),
|
||||
}
|
||||
|
||||
cryptoStream, _ := session.NewStream(1)
|
||||
cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream)
|
||||
cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager)
|
||||
|
||||
go func() {
|
||||
if err := cryptoSetup.HandleCryptoStream(); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user