use separate functions per encryption level to get sealers

This commit is contained in:
Marten Seemann
2019-06-10 15:16:25 +08:00
parent d4d3f09ee3
commit c503769bcd
6 changed files with 180 additions and 124 deletions

View File

@@ -564,41 +564,31 @@ func (h *cryptoSetup) SendAlert(alert uint8) {
h.alertChan <- alert
}
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
func (h *cryptoSetup) GetInitialSealer() (Sealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.sealer != nil {
return protocol.Encryption1RTT, h.sealer
}
if h.handshakeSealer != nil {
return protocol.EncryptionHandshake, h.handshakeSealer
}
return protocol.EncryptionInitial, h.initialSealer
return h.initialSealer, nil
}
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
func (h *cryptoSetup) GetHandshakeSealer() (Sealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
switch level {
case protocol.EncryptionInitial:
return h.initialSealer, nil
case protocol.EncryptionHandshake:
if h.handshakeSealer == nil {
return nil, errNoSealer
}
return h.handshakeSealer, nil
case protocol.Encryption1RTT:
if h.sealer == nil {
return nil, errNoSealer
}
return h.sealer, nil
default:
return nil, errNoSealer
if h.handshakeSealer == nil {
return nil, errors.New("CryptoSetup: no sealer with encryption level Handshake")
}
return h.handshakeSealer, nil
}
func (h *cryptoSetup) Get1RTTSealer() (Sealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.sealer == nil {
return nil, errors.New("CryptoSetup: no sealer with encryption level 1-RTT")
}
return h.sealer, nil
}
func (h *cryptoSetup) GetInitialOpener() (Opener, error) {

View File

@@ -49,6 +49,7 @@ type CryptoSetup interface {
GetHandshakeOpener() (Opener, error)
Get1RTTOpener() (Opener, error)
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetInitialSealer() (Sealer, error)
GetHandshakeSealer() (Sealer, error)
Get1RTTSealer() (Sealer, error)
}