From 2903f9b23893f401a95021296b04b2b026868d27 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Tue, 21 Mar 2017 16:13:00 +0700 Subject: [PATCH] return a lambda to for sealing a packet in the CryptoSetup --- handshake/crypto_setup_client.go | 68 ++++++++++++--------- handshake/crypto_setup_client_test.go | 40 +++++++------ handshake/crypto_setup_interface.go | 17 ------ handshake/crypto_setup_server.go | 61 ++++++++++--------- handshake/crypto_setup_server_test.go | 73 +++++++++++------------ handshake/interface.go | 19 ++++++ packet_packer.go | 48 +++++++-------- packet_packer_test.go | 85 +++++++++++++++++++-------- packet_unpacker.go | 1 - packet_unpacker_test.go | 2 +- 10 files changed, 230 insertions(+), 184 deletions(-) delete mode 100644 handshake/crypto_setup_interface.go create mode 100644 handshake/interface.go diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index 2148ae53..d287e6c5 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -92,7 +92,10 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { } // send CHLOs until the forward secure encryption is established - if h.forwardSecureAEAD == nil { + h.mutex.RLock() + sendCHLO := h.forwardSecureAEAD == nil + h.mutex.RUnlock() + if sendCHLO { err = h.sendCHLO() if err != nil { return err @@ -276,6 +279,9 @@ func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool { } func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { + h.mutex.RLock() + defer h.mutex.RUnlock() + if h.forwardSecureAEAD != nil { data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData) if err == nil { @@ -302,36 +308,50 @@ func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNu return res, protocol.EncryptionUnencrypted, nil } -func (h *cryptoSetupClient) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel) { +func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) { + h.mutex.RLock() + defer h.mutex.RUnlock() + if h.forwardSecureAEAD != nil { - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionForwardSecure + return protocol.EncryptionForwardSecure, h.sealForwardSecure + } else if h.secureAEAD != nil { + return protocol.EncryptionSecure, h.sealSecure + } else { + return protocol.EncryptionUnencrypted, h.sealUnencrypted } - if h.secureAEAD != nil { - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionSecure - } - return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData), protocol.EncryptionUnencrypted } -func (h *cryptoSetupClient) SealWith(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte, forceEncryptionLevel protocol.EncryptionLevel) ([]byte, protocol.EncryptionLevel, error) { - switch forceEncryptionLevel { +func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { + switch encLevel { case protocol.EncryptionUnencrypted: - return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData), protocol.EncryptionUnencrypted, nil + return h.sealUnencrypted, nil case protocol.EncryptionSecure: if h.secureAEAD == nil { - return nil, protocol.EncryptionUnspecified, errors.New("CryptoSetupClient: no secureAEAD") + return nil, errors.New("CryptoSetupClient: no secureAEAD") } - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionSecure, nil + return h.sealSecure, nil case protocol.EncryptionForwardSecure: if h.forwardSecureAEAD == nil { - return nil, protocol.EncryptionUnspecified, errors.New("CryptoSetupClient: no forwardSecureAEAD") + return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD") } - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionForwardSecure, nil + return h.sealForwardSecure, nil } - - return nil, protocol.EncryptionUnspecified, errors.New("no encryption level specified") + return nil, errors.New("CryptoSetupClient: no encryption level specified") } -func (h *cryptoSetupClient) DiversificationNonce(bool) []byte { +func (h *cryptoSetupClient) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupClient) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupClient) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupClient) DiversificationNonce() []byte { panic("not needed for cryptoSetupClient") } @@ -346,19 +366,11 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) error { return nil } -func (h *cryptoSetupClient) LockForSealing() { - -} - -func (h *cryptoSetupClient) UnlockForSealing() { - -} - func (h *cryptoSetupClient) HandshakeComplete() bool { h.mutex.RLock() - complete := h.forwardSecureAEAD != nil - h.mutex.RUnlock() - return complete + defer h.mutex.RUnlock() + + return h.forwardSecureAEAD != nil } func (h *cryptoSetupClient) sendCHLO() error { diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index b8a88519..98bc8fea 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -676,9 +676,10 @@ var _ = Describe("Crypto setup", func() { Context("null encryption", func() { It("is used initially", func() { - d, enc := cs.Seal(nil, []byte("foobar"), 0, []byte{}) - Expect(d).To(Equal(foobarFNVSigned)) + enc, seal := cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) + d := seal(nil, []byte("foobar"), 0, []byte{}) + Expect(d).To(Equal(foobarFNVSigned)) }) It("is accepted initially", func() { @@ -711,9 +712,10 @@ var _ = Describe("Crypto setup", func() { It("is used immediately when available", func() { doCompleteREJ() cs.receivedSecurePacket = false - d, enc := cs.Seal(nil, []byte("foobar"), 0, []byte{}) - Expect(d).To(Equal([]byte("foobar normal sec"))) + enc, seal := cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionSecure)) + d := seal(nil, []byte("foobar"), 0, []byte{}) + Expect(d).To(Equal([]byte("foobar normal sec"))) }) It("is accepted", func() { @@ -739,51 +741,53 @@ var _ = Describe("Crypto setup", func() { _, enc, err := cs.Open(nil, []byte("forward secure encrypted"), 0, []byte{}) Expect(err).ToNot(HaveOccurred()) Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) - d, enc := cs.Seal(nil, []byte("foobar"), 0, []byte{}) - Expect(d).To(Equal([]byte("foobar forward sec"))) + enc, seal := cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) + d := seal(nil, []byte("foobar"), 0, []byte{}) + Expect(d).To(Equal([]byte("foobar forward sec"))) }) }) Context("forcing encryption levels", func() { It("forces null encryption", func() { - d, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionUnencrypted) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted) Expect(err).ToNot(HaveOccurred()) + d := seal(nil, []byte("foobar"), 0, []byte{}) Expect(d).To(Equal(foobarFNVSigned)) - Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) }) It("forces initial encryption", func() { doCompleteREJ() - d, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) Expect(err).ToNot(HaveOccurred()) + d := seal(nil, []byte("foobar"), 0, []byte{}) Expect(d).To(Equal([]byte("foobar normal sec"))) - Expect(enc).To(Equal(protocol.EncryptionSecure)) }) It("errors of no AEAD for initial encryption is available", func() { - _, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) Expect(err).To(MatchError("CryptoSetupClient: no secureAEAD")) - Expect(enc).To(Equal(protocol.EncryptionUnspecified)) + Expect(seal).To(BeNil()) }) It("forces forward-secure encryption", func() { doSHLO() - d, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionForwardSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) Expect(err).ToNot(HaveOccurred()) + d := seal(nil, []byte("foobar"), 0, []byte{}) Expect(d).To(Equal([]byte("foobar forward sec"))) - Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) }) It("errors of no AEAD for forward-secure encryption is available", func() { - _, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionForwardSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) Expect(err).To(MatchError("CryptoSetupClient: no forwardSecureAEAD")) - Expect(enc).To(Equal(protocol.EncryptionUnspecified)) + Expect(seal).To(BeNil()) }) It("errors if no encryption level is specified", func() { - _, _, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionUnspecified) - Expect(err).To(MatchError("no encryption level specified")) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified) + Expect(err).To(MatchError("CryptoSetupClient: no encryption level specified")) + Expect(seal).To(BeNil()) }) }) }) diff --git a/handshake/crypto_setup_interface.go b/handshake/crypto_setup_interface.go deleted file mode 100644 index 90e09bd4..00000000 --- a/handshake/crypto_setup_interface.go +++ /dev/null @@ -1,17 +0,0 @@ -package handshake - -import "github.com/lucas-clemente/quic-go/protocol" - -// CryptoSetup is a crypto setup -type CryptoSetup interface { - HandleCryptoStream() error - Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) - Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel) - SealWith(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte, forceEncryptionLevel protocol.EncryptionLevel) ([]byte, protocol.EncryptionLevel, error) - LockForSealing() - UnlockForSealing() - HandshakeComplete() bool - // TODO: clean up this interface - DiversificationNonce(force bool) []byte // only needed for cryptoSetupServer - SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient -} diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 2a1813d2..336cfe76 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -195,37 +195,49 @@ func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNu return res, protocol.EncryptionUnencrypted, err } -// Seal a message, call LockForSealing() before! -func (h *cryptoSetupServer) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel) { +func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) { + h.mutex.RLock() + defer h.mutex.RUnlock() + if h.forwardSecureAEAD != nil && h.sentSHLO { - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionForwardSecure + return protocol.EncryptionForwardSecure, h.sealForwardSecure } else if h.secureAEAD != nil { // secureAEAD and forwardSecureAEAD are created at the same time (when receiving the CHLO) // make sure that the SHLO isn't sent forward-secure - h.sentSHLO = true - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionSecure - } else { - return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData), protocol.EncryptionUnencrypted + return protocol.EncryptionSecure, h.sealSecure } + return protocol.EncryptionUnencrypted, h.sealUnencrypted } -func (h *cryptoSetupServer) SealWith(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte, forceEncryptionLevel protocol.EncryptionLevel) ([]byte, protocol.EncryptionLevel, error) { - switch forceEncryptionLevel { +func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) { + switch encLevel { case protocol.EncryptionUnencrypted: - return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData), protocol.EncryptionUnencrypted, nil + return h.sealUnencrypted, nil case protocol.EncryptionSecure: if h.secureAEAD == nil { - return nil, protocol.EncryptionUnspecified, errors.New("CryptoSetupServer: no secureAEAD") + return nil, errors.New("CryptoSetupServer: no secureAEAD") } - return h.secureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionSecure, nil + return h.sealSecure, nil case protocol.EncryptionForwardSecure: if h.forwardSecureAEAD == nil { - return nil, protocol.EncryptionUnspecified, errors.New("CryptoSetupServer: no forwardSecureAEAD") + return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD") } - return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData), protocol.EncryptionForwardSecure, nil + return h.sealForwardSecure, nil } + return nil, errors.New("CryptoSetupServer: no encryption level specified") +} - return nil, protocol.EncryptionUnspecified, errors.New("no encryption level specified") +func (h *cryptoSetupServer) sealUnencrypted(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return (&crypto.NullAEAD{}).Seal(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupServer) sealSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + h.sentSHLO = true + return h.secureAEAD.Seal(dst, src, packetNumber, associatedData) +} + +func (h *cryptoSetupServer) sealForwardSecure(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return h.forwardSecureAEAD.Seal(dst, src, packetNumber, associatedData) } func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool { @@ -398,28 +410,15 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T return reply.Bytes(), nil } -// DiversificationNonce returns a diversification nonce if required in the next packet to be Seal'ed. See LockForSealing()! -func (h *cryptoSetupServer) DiversificationNonce(force bool) []byte { - if force || (h.secureAEAD != nil && !h.sentSHLO) { - return h.diversificationNonce - } - return nil +// DiversificationNonce returns the diversification nonce +func (h *cryptoSetupServer) DiversificationNonce() []byte { + return h.diversificationNonce } func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) error { panic("not needed for cryptoSetupServer") } -// LockForSealing should be called before Seal(). It is needed so that diversification nonces can be obtained before packets are sealed, and the AEADs are not changed in the meantime. -func (h *cryptoSetupServer) LockForSealing() { - h.mutex.RLock() -} - -// UnlockForSealing should be called after Seal() is complete, see LockForSealing(). -func (h *cryptoSetupServer) UnlockForSealing() { - h.mutex.RUnlock() -} - // HandshakeComplete returns true after the first forward secure packet was received form the client. func (h *cryptoSetupServer) HandshakeComplete() bool { return h.receivedForwardSecurePacket diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index ba250e1d..dddee851 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -185,28 +185,13 @@ var _ = Describe("Crypto setup", func() { cs.secureAEAD = &mockAEAD{} cs.receivedForwardSecurePacket = false - Expect(cs.DiversificationNonce(false)).To(BeEmpty()) + Expect(cs.DiversificationNonce()).To(BeEmpty()) // Div nonce is created after CHLO cs.handleCHLO("", nil, map[Tag][]byte{TagNONC: nonce32}) }) It("returns diversification nonces", func() { - Expect(cs.DiversificationNonce(false)).To(HaveLen(32)) - }) - - It("does not return nonce after sending the SHLO", func() { - cs.sentSHLO = true - Expect(cs.DiversificationNonce(false)).To(BeEmpty()) - }) - - It("returns a nonce for a retransmission, even after sending the SHLO", func() { - cs.sentSHLO = true - Expect(cs.DiversificationNonce(true)).To(HaveLen(32)) - }) - - It("does not return nonce for unencrypted packets", func() { - cs.secureAEAD = nil - Expect(cs.DiversificationNonce(false)).To(BeEmpty()) + Expect(cs.DiversificationNonce()).To(HaveLen(32)) }) }) @@ -590,9 +575,10 @@ var _ = Describe("Crypto setup", func() { Context("null encryption", func() { It("is used initially", func() { - d, enc := cs.Seal(nil, []byte("foobar"), 0, []byte{}) - Expect(d).To(Equal(foobarFNVSigned)) + enc, seal := cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) + d := seal(nil, []byte("foobar"), 0, []byte{}) + Expect(d).To(Equal(foobarFNVSigned)) }) It("is accepted initially", func() { @@ -624,18 +610,20 @@ var _ = Describe("Crypto setup", func() { It("is not used after CHLO", func() { doCHLO() - d, enc := cs.Seal(nil, []byte("foobar"), 0, []byte{}) - Expect(d).ToNot(Equal(foobarFNVSigned)) + enc, seal := cs.GetSealer() Expect(enc).ToNot(Equal(protocol.EncryptionUnencrypted)) + d := seal(nil, []byte("foobar"), 0, []byte{}) + Expect(d).ToNot(Equal(foobarFNVSigned)) }) }) Context("initial encryption", func() { It("is used after CHLO", func() { doCHLO() - d, enc := cs.Seal(nil, []byte("foobar"), 0, []byte{}) - Expect(d).To(Equal([]byte("foobar normal sec"))) + enc, seal := cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionSecure)) + d := seal(nil, []byte("foobar"), 0, []byte{}) + Expect(d).To(Equal([]byte("foobar normal sec"))) }) It("is accepted after CHLO", func() { @@ -659,19 +647,23 @@ var _ = Describe("Crypto setup", func() { Context("forward secure encryption", func() { It("is used after sending out one packet with initial encryption", func() { doCHLO() - _, enc := cs.Seal(nil, []byte("SHLO"), 0, []byte{}) + enc, seal := cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionSecure)) - d, enc := cs.Seal(nil, []byte("foobar"), 0, []byte{}) - Expect(d).To(Equal([]byte("foobar forward sec"))) + _ = seal(nil, []byte("SHLO"), 0, []byte{}) + enc, seal = cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) + d := seal(nil, []byte("foobar"), 0, []byte{}) + Expect(d).To(Equal([]byte("foobar forward sec"))) }) It("regards the handshake as complete once it receives a forward encrypted packet", func() { doCHLO() - _, enc := cs.Seal(nil, []byte("SHLO"), 0, []byte{}) + enc, seal := cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionSecure)) - _, enc = cs.Seal(nil, []byte("foobar"), 0, []byte{}) + _ = seal(nil, []byte("SHLO"), 0, []byte{}) + enc, seal = cs.GetSealer() Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) + _ = seal(nil, []byte("foobar"), 0, []byte{}) Expect(cs.HandshakeComplete()).To(BeFalse()) cs.receivedForwardSecurePacket = true Expect(cs.HandshakeComplete()).To(BeTrue()) @@ -680,43 +672,44 @@ var _ = Describe("Crypto setup", func() { Context("forcing encryption levels", func() { It("forces null encryption", func() { - d, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionUnencrypted) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnencrypted) Expect(err).ToNot(HaveOccurred()) + d := seal(nil, []byte("foobar"), 0, []byte{}) Expect(d).To(Equal(foobarFNVSigned)) - Expect(enc).To(Equal(protocol.EncryptionUnencrypted)) }) It("forces initial encryption", func() { doCHLO() - d, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) Expect(err).ToNot(HaveOccurred()) + d := seal(nil, []byte("foobar"), 0, []byte{}) Expect(d).To(Equal([]byte("foobar normal sec"))) - Expect(enc).To(Equal(protocol.EncryptionSecure)) }) It("errors of no AEAD for initial encryption is available", func() { - _, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionSecure) Expect(err).To(MatchError("CryptoSetupServer: no secureAEAD")) - Expect(enc).To(Equal(protocol.EncryptionUnspecified)) + Expect(seal).To(BeNil()) }) It("forces forward-secure encryption", func() { doCHLO() - d, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionForwardSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) Expect(err).ToNot(HaveOccurred()) + d := seal(nil, []byte("foobar"), 0, []byte{}) Expect(d).To(Equal([]byte("foobar forward sec"))) - Expect(enc).To(Equal(protocol.EncryptionForwardSecure)) }) It("errors of no AEAD for forward-secure encryption is available", func() { - _, enc, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionForwardSecure) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionForwardSecure) Expect(err).To(MatchError("CryptoSetupServer: no forwardSecureAEAD")) - Expect(enc).To(Equal(protocol.EncryptionUnspecified)) + Expect(seal).To(BeNil()) }) It("errors if no encryption level is specified", func() { - _, _, err := cs.SealWith(nil, []byte("foobar"), 0, []byte{}, protocol.EncryptionUnspecified) - Expect(err).To(MatchError("no encryption level specified")) + seal, err := cs.GetSealerWithEncryptionLevel(protocol.EncryptionUnspecified) + Expect(err).To(MatchError("CryptoSetupServer: no encryption level specified")) + Expect(seal).To(BeNil()) }) }) }) diff --git a/handshake/interface.go b/handshake/interface.go new file mode 100644 index 00000000..a6629cc4 --- /dev/null +++ b/handshake/interface.go @@ -0,0 +1,19 @@ +package handshake + +import "github.com/lucas-clemente/quic-go/protocol" + +// Sealer seals a packet +type Sealer func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte + +// CryptoSetup is a crypto setup +type CryptoSetup interface { + Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) + HandleCryptoStream() error + HandshakeComplete() bool + // TODO: clean up this interface + DiversificationNonce() []byte // only needed for cryptoSetupServer + SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient + + GetSealer() (protocol.EncryptionLevel, Sealer) + GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) +} diff --git a/packet_packer.go b/packet_packer.go index 10161899..08db29e6 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -75,13 +75,23 @@ func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, con return p.packPacket(stopWaitingFrame, leastUnacked, nil) } -func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber, packetToRetransmit *ackhandler.Packet) (*packedPacket, error) { - // packetToRetransmit is only set for handshake retransmissions - isHandshakeRetransmission := (packetToRetransmit != nil) - // cryptoSetup needs to be locked here, so that the AEADs are not changed between - // calling DiversificationNonce() and Seal(). - p.cryptoSetup.LockForSealing() - defer p.cryptoSetup.UnlockForSealing() +func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber, handshakePacketToRetransmit *ackhandler.Packet) (*packedPacket, error) { + // handshakePacketToRetransmit is only set for handshake retransmissions + isHandshakeRetransmission := (handshakePacketToRetransmit != nil) + + var sealFunc handshake.Sealer + var encLevel protocol.EncryptionLevel + + if isHandshakeRetransmission { + var err error + encLevel = handshakePacketToRetransmit.EncryptionLevel + sealFunc, err = p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel) + if err != nil { + return nil, err + } + } else { + encLevel, sealFunc = p.cryptoSetup.GetSealer() + } currentPacketNumber := p.packetNumberGenerator.Peek() packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked) @@ -92,12 +102,11 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea TruncateConnectionID: p.connectionParameters.TruncateConnectionID(), } - if p.perspective == protocol.PerspectiveServer { - force := isHandshakeRetransmission && (packetToRetransmit.EncryptionLevel == protocol.EncryptionSecure) - responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce(force) + if p.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionSecure { + responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce() } - if p.perspective == protocol.PerspectiveClient && !p.cryptoSetup.HandshakeComplete() { + if p.perspective == protocol.PerspectiveClient && encLevel != protocol.EncryptionForwardSecure { responsePublicHeader.VersionFlag = true responsePublicHeader.VersionNumber = p.version } @@ -122,7 +131,7 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea if isHandshakeRetransmission { payloadFrames = append(payloadFrames, stopWaitingFrame) // don't retransmit Acks and StopWaitings - for _, f := range packetToRetransmit.Frames { + for _, f := range handshakePacketToRetransmit.Frames { switch f.(type) { case *frames.AckFrame: continue @@ -178,19 +187,10 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea } raw = raw[0:buffer.Len()] - var encryptionLevel protocol.EncryptionLevel - if isHandshakeRetransmission { - var err error - _, encryptionLevel, err = p.cryptoSetup.SealWith(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], currentPacketNumber, raw[:payloadStartIndex], packetToRetransmit.EncryptionLevel) - if err != nil { - return nil, err - } - } else { - _, encryptionLevel = p.cryptoSetup.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], currentPacketNumber, raw[:payloadStartIndex]) - } + _ = sealFunc(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], currentPacketNumber, raw[:payloadStartIndex]) raw = raw[0 : buffer.Len()+12] - if hasNonCryptoStreamData && encryptionLevel <= protocol.EncryptionUnencrypted { + if hasNonCryptoStreamData && encLevel <= protocol.EncryptionUnencrypted { return nil, qerr.AttemptToSendUnencryptedStreamData } @@ -203,7 +203,7 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea number: currentPacketNumber, raw: raw, frames: payloadFrames, - encryptionLevel: encryptionLevel, + encryptionLevel: encLevel, }, nil } diff --git a/packet_packer_test.go b/packet_packer_test.go index 20b236b8..e7081b7d 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -5,6 +5,7 @@ import ( "github.com/lucas-clemente/quic-go/ackhandler" "github.com/lucas-clemente/quic-go/frames" + "github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/qerr" . "github.com/onsi/ginkgo" @@ -13,7 +14,6 @@ import ( type mockCryptoSetup struct { divNonce []byte - forcedDivNonce bool handshakeComplete bool encLevelSeal protocol.EncryptionLevel } @@ -23,21 +23,26 @@ func (m *mockCryptoSetup) HandleCryptoStream() error { return nil } func (m *mockCryptoSetup) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) { return nil, protocol.EncryptionUnspecified, nil } -func (m *mockCryptoSetup) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel) { - return append(src, bytes.Repeat([]byte{0}, 12)...), m.encLevelSeal +func (m *mockCryptoSetup) GetSealer() (protocol.EncryptionLevel, handshake.Sealer) { + return m.encLevelSeal, func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return append(src, bytes.Repeat([]byte{0}, 12)...) + } } -func (m *mockCryptoSetup) SealWith(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte, encLevel protocol.EncryptionLevel) ([]byte, protocol.EncryptionLevel, error) { - return append(src, bytes.Repeat([]byte{0}, 12)...), encLevel, nil +func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error) { + return func(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte { + return append(src, bytes.Repeat([]byte{0}, 12)...) + }, nil } func (m *mockCryptoSetup) LockForSealing() {} func (m *mockCryptoSetup) UnlockForSealing() {} func (m *mockCryptoSetup) HandshakeComplete() bool { return m.handshakeComplete } -func (m *mockCryptoSetup) DiversificationNonce(force bool) []byte { - m.forcedDivNonce = force +func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } func (m *mockCryptoSetup) SetDiversificationNonce([]byte) error { panic("not implemented") } +var _ handshake.CryptoSetup = &mockCryptoSetup{} + var _ = Describe("Packet packer", func() { var ( packer *packetPacker @@ -102,17 +107,47 @@ var _ = Describe("Packet packer", func() { Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) }) - It("includes a diversification nonce, when acting as a server", func() { - nonce := bytes.Repeat([]byte{'e'}, 32) - packer.cryptoSetup.(*mockCryptoSetup).divNonce = nonce - f := &frames.StreamFrame{ - StreamID: 5, - Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, - } - streamFramer.AddFrameForRetransmission(f) - p, err := packer.PackPacket(nil, []frames.Frame{}, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(p.raw).To(ContainSubstring(string(nonce))) + Context("diversificaton nonces", func() { + var nonce []byte + + BeforeEach(func() { + nonce = bytes.Repeat([]byte{'e'}, 32) + packer.cryptoSetup.(*mockCryptoSetup).divNonce = nonce + f := &frames.StreamFrame{ + StreamID: 1, + Data: []byte{0xDE, 0xCA, 0xFB, 0xAD}, + } + streamFramer.AddFrameForRetransmission(f) + }) + + It("doesn't include a div nonce, when sending a packet with initial encryption", func() { + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(p.raw).ToNot(ContainSubstring(string(nonce))) + }) + + It("includes a div nonce, when sending a packet with secure encryption", func() { + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(p.raw).To(ContainSubstring(string(nonce))) + }) + + It("doesn't include a div nonce, when sending a packet with forward-secure encryption", func() { + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionUnencrypted + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(p.raw).ToNot(ContainSubstring(string(nonce))) + }) + + It("doesn't send a div nonce as a client", func() { + packer.perspective = protocol.PerspectiveClient + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure + p, err := packer.PackPacket(nil, []frames.Frame{}, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(p.raw).ToNot(ContainSubstring(string(nonce))) + }) }) It("packs a ConnectionClose", func() { @@ -195,7 +230,7 @@ var _ = Describe("Packet packer", func() { It("adds the version flag to the public header before the crypto handshake is finished", func() { packer.perspective = protocol.PerspectiveClient - packer.cryptoSetup.(*mockCryptoSetup).handshakeComplete = false + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionSecure packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.connectionID = 0x1337 packer.version = 123 @@ -208,9 +243,9 @@ var _ = Describe("Packet packer", func() { Expect(hdr.VersionNumber).To(Equal(packer.version)) }) - It("doesn't add the version flag to the public header after the crypto handshake is completed", func() { + It("doesn't add the version flag to the public header for forward-secure packets", func() { packer.perspective = protocol.PerspectiveClient - packer.cryptoSetup.(*mockCryptoSetup).handshakeComplete = true + packer.cryptoSetup.(*mockCryptoSetup).encLevelSeal = protocol.EncryptionForwardSecure packer.controlFrames = []frames.Frame{&frames.BlockedFrame{StreamID: 0}} packer.connectionID = 0x1337 p, err := packer.PackPacket(nil, []frames.Frame{}, 0) @@ -593,11 +628,11 @@ var _ = Describe("Packet packer", func() { Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(ContainElement(swf)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionUnencrypted)) - // unencrypted packets don't need a diversification nonce - Expect(packer.cryptoSetup.(*mockCryptoSetup).forcedDivNonce).To(BeFalse()) }) It("packs a retransmission for a packet sent with initial encryption", func() { + nonce := bytes.Repeat([]byte{'e'}, 32) + packer.cryptoSetup.(*mockCryptoSetup).divNonce = nonce packet := &ackhandler.Packet{ EncryptionLevel: protocol.EncryptionSecure, Frames: []frames.Frame{sf}, @@ -607,6 +642,9 @@ var _ = Describe("Packet packer", func() { Expect(p.frames).To(ContainElement(sf)) Expect(p.frames).To(ContainElement(swf)) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) + // a packet sent by the server with initial encryption contains the SHLO + // it needs to have a diversification nonce + Expect(p.raw).To(ContainSubstring(string(nonce))) }) It("includes the diversification nonce on packets sent with initial encryption", func() { @@ -617,7 +655,6 @@ var _ = Describe("Packet packer", func() { p, err := packer.RetransmitNonForwardSecurePacket(swf, packet) Expect(err).ToNot(HaveOccurred()) Expect(p.encryptionLevel).To(Equal(protocol.EncryptionSecure)) - Expect(packer.cryptoSetup.(*mockCryptoSetup).forcedDivNonce).To(BeTrue()) }) It("removes non-retransmittable frames", func() { diff --git a/packet_unpacker.go b/packet_unpacker.go index 350faa70..30dee80a 100644 --- a/packet_unpacker.go +++ b/packet_unpacker.go @@ -12,7 +12,6 @@ import ( type quicAEAD interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) - Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel) } type packetUnpacker struct { diff --git a/packet_unpacker_test.go b/packet_unpacker_test.go index acd07077..90d4d083 100644 --- a/packet_unpacker_test.go +++ b/packet_unpacker_test.go @@ -47,7 +47,7 @@ var _ = Describe("Packet unpacker", func() { }) setData := func(p []byte) { - data, _ = unpacker.aead.Seal(nil, p, 0, hdrBin) + data, _ = unpacker.aead.(*mockAEAD).Seal(nil, p, 0, hdrBin) } It("does not read read a private flag for QUIC Version >= 34", func() {