diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index fe6fcef2..33bf49c3 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -334,6 +334,8 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T return nil, err } + h.aeadChanged <- protocol.EncryptionSecure + // Generate a new curve instance to derive the forward secure key var fsNonce bytes.Buffer fsNonce.Write(clientNonce) diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index 39313ff0..6232a151 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -336,7 +336,11 @@ var _ = Describe("Crypto setup", func() { Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) Expect(stream.dataWritten.Bytes()).ToNot(ContainSubstring("REJ")) - Expect(aeadChanged).To(Receive()) + var encLevel protocol.EncryptionLevel + Expect(aeadChanged).To(Receive(&encLevel)) + Expect(encLevel).To(Equal(protocol.EncryptionSecure)) + Expect(aeadChanged).To(Receive(&encLevel)) + Expect(encLevel).To(Equal(protocol.EncryptionForwardSecure)) }) It("recognizes inchoate CHLOs missing SCID", func() { diff --git a/packet_packer.go b/packet_packer.go index e1833ea4..6bf5515f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -23,6 +23,8 @@ type packetPacker struct { perspective protocol.Perspective version protocol.VersionNumber cryptoSetup handshake.CryptoSetup + // as long as packets are not sent with forward-secure encryption, we limit the MaxPacketSize such that they can be retransmitted as a whole + isForwardSecure bool packetNumberGenerator *packetNumberGenerator @@ -105,6 +107,9 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, lea payloadFrames = []frames.Frame{p.controlFrames[0]} } else { maxSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength + if !p.isForwardSecure { + maxSize -= protocol.NonForwardSecurePacketSizeReduction + } payloadFrames, err = p.composeNextPacket(stopWaitingFrame, maxSize) if err != nil { return nil, err @@ -218,3 +223,7 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra func (p *packetPacker) QueueControlFrameForNextPacket(f frames.Frame) { p.controlFrames = append(p.controlFrames, f) } + +func (p *packetPacker) SetForwardSecure() { + p.isForwardSecure = true +} diff --git a/packet_packer_test.go b/packet_packer_test.go index 24e9f389..c6c46d87 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -60,6 +60,7 @@ var _ = Describe("Packet packer", func() { publicHeaderLen = 1 + 8 + 2 // 1 flag byte, 8 connection ID, 2 packet number maxFrameSize = protocol.MaxFrameAndPublicHeaderSize - publicHeaderLen packer.version = protocol.Version34 + packer.isForwardSecure = true }) It("returns nil when no packet is queued", func() { @@ -314,6 +315,18 @@ var _ = Describe("Packet packer", func() { Expect(p.frames[0].(*frames.StreamFrame).DataLenPresent).To(BeFalse()) }) + It("packs smaller packets when it is not yet forward-secure", func() { + packer.isForwardSecure = false + f := &frames.StreamFrame{ + StreamID: 3, + Data: bytes.Repeat([]byte{'f'}, int(protocol.MaxPacketSize)), + } + streamFramer.AddFrameForRetransmission(f) + p, err := packer.PackPacket(nil, nil, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(p.raw).To(HaveLen(int(protocol.MaxPacketSize - protocol.NonForwardSecurePacketSizeReduction))) + }) + It("packs multiple small stream frames into single packet", func() { f1 := &frames.StreamFrame{ StreamID: 5, diff --git a/protocol/server_parameters.go b/protocol/server_parameters.go index d7173045..1b2c9b40 100644 --- a/protocol/server_parameters.go +++ b/protocol/server_parameters.go @@ -2,6 +2,10 @@ package protocol import "time" +// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet +// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames +const NonForwardSecurePacketSizeReduction = 50 + // DefaultMaxCongestionWindow is the default for the max congestion window const DefaultMaxCongestionWindow = 1000 diff --git a/session.go b/session.go index 53271e78..c9d31239 100644 --- a/session.go +++ b/session.go @@ -236,6 +236,9 @@ runLoop: // begins with the public header and we never copy it. putPacketBuffer(p.publicHeader.Raw) case l := <-s.aeadChanged: + if l == protocol.EncryptionForwardSecure { + s.packer.SetForwardSecure() + } s.tryDecryptingQueuedPackets() s.cryptoChangeCallback(s, l == protocol.EncryptionForwardSecure) } diff --git a/session_test.go b/session_test.go index 519cc300..07b8eecb 100644 --- a/session_test.go +++ b/session_test.go @@ -1083,6 +1083,14 @@ var _ = Describe("Session", func() { }) }) + It("tells the packetPacker when forward-secure encryption is used", func() { + go sess.run() + sess.aeadChanged <- protocol.EncryptionSecure + Consistently(func() bool { return sess.packer.isForwardSecure }).Should(BeFalse()) + sess.aeadChanged <- protocol.EncryptionForwardSecure + Eventually(func() bool { return sess.packer.isForwardSecure }).Should(BeTrue()) + }) + It("closes when crypto stream errors", func() { go sess.run() s, err := sess.GetOrOpenStream(3)