From 98233f674325717d8e0c34922f10f8921edb552c Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 22 Apr 2020 17:00:28 +0700 Subject: [PATCH 1/4] add a way to pack coalesced packets smaller than the usual packet size --- mock_packer_test.go | 8 ++-- packet_packer.go | 46 ++++++++++++----------- packet_packer_test.go | 86 +++++++++++++++++++++++++++++++++++-------- session.go | 2 +- session_test.go | 24 ++++++------ 5 files changed, 113 insertions(+), 53 deletions(-) diff --git a/mock_packer_test.go b/mock_packer_test.go index d05be9c00..6f95fb1a9 100644 --- a/mock_packer_test.go +++ b/mock_packer_test.go @@ -79,18 +79,18 @@ func (mr *MockPackerMockRecorder) MaybePackProbePacket(arg0 interface{}) *gomock } // PackCoalescedPacket mocks base method -func (m *MockPacker) PackCoalescedPacket() (*coalescedPacket, error) { +func (m *MockPacker) PackCoalescedPacket(arg0 protocol.ByteCount) (*coalescedPacket, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PackCoalescedPacket") + ret := m.ctrl.Call(m, "PackCoalescedPacket", arg0) ret0, _ := ret[0].(*coalescedPacket) ret1, _ := ret[1].(error) return ret0, ret1 } // PackCoalescedPacket indicates an expected call of PackCoalescedPacket -func (mr *MockPackerMockRecorder) PackCoalescedPacket() *gomock.Call { +func (mr *MockPackerMockRecorder) PackCoalescedPacket(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PackCoalescedPacket", reflect.TypeOf((*MockPacker)(nil).PackCoalescedPacket), arg0) } // PackConnectionClose mocks base method diff --git a/packet_packer.go b/packet_packer.go index c6fa01996..92a771c19 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -17,7 +17,7 @@ import ( ) type packer interface { - PackCoalescedPacket() (*coalescedPacket, error) + PackCoalescedPacket(protocol.ByteCount) (*coalescedPacket, error) PackPacket() (*packedPacket, error) MaybePackProbePacket(protocol.EncryptionLevel) (*packedPacket, error) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) @@ -323,14 +323,14 @@ func (p *packetPacker) padPacket(buffer *packetBuffer) { // PackCoalescedPacket packs a new packet. // It packs an Initial / Handshake if there is data to send in these packet number spaces. // It should only be called before the handshake is confirmed. -func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { +func (p *packetPacker) PackCoalescedPacket(maxPacketSize protocol.ByteCount) (*coalescedPacket, error) { buffer := getPacketBuffer() - packet, err := p.packCoalescedPacket(buffer) + packet, err := p.packCoalescedPacket(buffer, maxPacketSize) if err != nil { return nil, err } - if len(packet.packets) == 0 { // nothing to send + if packet == nil || len(packet.packets) == 0 { // nothing to send buffer.Release() return nil, nil } @@ -342,37 +342,45 @@ func (p *packetPacker) PackCoalescedPacket() (*coalescedPacket, error) { return packet, nil } -func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPacket, error) { +func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount) (*coalescedPacket, error) { + maxPacketSize = utils.MinByteCount(maxPacketSize, p.maxPacketSize) + if p.perspective == protocol.PerspectiveClient { + maxPacketSize = protocol.MinInitialPacketSize + } + if maxPacketSize < protocol.MinCoalescedPacketSize { + return nil, nil + } + packet := &coalescedPacket{ buffer: buffer, packets: make([]*packetContents, 0, 3), } // Try packing an Initial packet. - contents, err := p.maybeAppendCryptoPacket(buffer, protocol.EncryptionInitial) + contents, err := p.maybeAppendCryptoPacket(buffer, maxPacketSize, protocol.EncryptionInitial) if err != nil && err != handshake.ErrKeysDropped { return nil, err } if contents != nil { packet.packets = append(packet.packets, contents) } - if buffer.Len() >= p.maxPacketSize-protocol.MinCoalescedPacketSize { + if buffer.Len() >= maxPacketSize-protocol.MinCoalescedPacketSize { return packet, nil } // Add a Handshake packet. - contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionHandshake) + contents, err = p.maybeAppendCryptoPacket(buffer, maxPacketSize, protocol.EncryptionHandshake) if err != nil && err != handshake.ErrKeysDropped && err != handshake.ErrKeysNotYetAvailable { return nil, err } if contents != nil { packet.packets = append(packet.packets, contents) } - if buffer.Len() >= p.maxPacketSize-protocol.MinCoalescedPacketSize { + if buffer.Len() >= maxPacketSize-protocol.MinCoalescedPacketSize { return packet, nil } // Add a 0-RTT / 1-RTT packet. - contents, err = p.maybeAppendAppDataPacket(buffer) + contents, err = p.maybeAppendAppDataPacket(buffer, maxPacketSize) if err == handshake.ErrKeysNotYetAvailable { return packet, nil } @@ -389,7 +397,7 @@ func (p *packetPacker) packCoalescedPacket(buffer *packetBuffer) (*coalescedPack // It should be called after the handshake is confirmed. func (p *packetPacker) PackPacket() (*packedPacket, error) { buffer := getPacketBuffer() - contents, err := p.maybeAppendAppDataPacket(buffer) + contents, err := p.maybeAppendAppDataPacket(buffer, p.maxPacketSize) if err != nil || contents == nil { buffer.Release() return nil, err @@ -400,16 +408,12 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { }, nil } -func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, encLevel protocol.EncryptionLevel) (*packetContents, error) { +func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel) (*packetContents, error) { var sealer sealer var s cryptoStream var hasRetransmission bool - maxPacketSize := p.maxPacketSize switch encLevel { case protocol.EncryptionInitial: - if p.perspective == protocol.PerspectiveClient { - maxPacketSize = protocol.MinInitialPacketSize - } s = p.initialStream hasRetransmission = p.retransmissionQueue.HasInitialData() var err error @@ -471,7 +475,7 @@ func (p *packetPacker) maybeAppendCryptoPacket(buffer *packetBuffer, encLevel pr return p.appendPacket(buffer, hdr, payload, encLevel, sealer) } -func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer) (*packetContents, error) { +func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer, maxPacketSize protocol.ByteCount) (*packetContents, error) { var sealer sealer var header *wire.ExtendedHeader var encLevel protocol.EncryptionLevel @@ -494,7 +498,7 @@ func (p *packetPacker) maybeAppendAppDataPacket(buffer *packetBuffer) (*packetCo } headerLen := header.GetLength(p.version) - maxSize := p.maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen + maxSize := maxPacketSize - buffer.Len() - protocol.ByteCount(sealer.Overhead()) - headerLen payload := p.composeNextPacket(maxSize, encLevel != protocol.Encryption0RTT && buffer.Len() == 0) // check if we have anything to send @@ -557,11 +561,11 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( buffer := getPacketBuffer() switch encLevel { case protocol.EncryptionInitial: - contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionInitial) + contents, err = p.maybeAppendCryptoPacket(buffer, p.maxPacketSize, protocol.EncryptionInitial) case protocol.EncryptionHandshake: - contents, err = p.maybeAppendCryptoPacket(buffer, protocol.EncryptionHandshake) + contents, err = p.maybeAppendCryptoPacket(buffer, p.maxPacketSize, protocol.EncryptionHandshake) case protocol.Encryption1RTT: - contents, err = p.maybeAppendAppDataPacket(buffer) + contents, err = p.maybeAppendAppDataPacket(buffer, p.maxPacketSize) default: panic("unknown encryption level") } diff --git a/packet_packer_test.go b/packet_packer_test.go index d0a435f52..69aacaba2 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -176,7 +176,7 @@ var _ = Describe("Packet packer", func() { expectAppendControlFrames() f := &wire.StreamFrame{Data: []byte{0xde, 0xca, 0xfb, 0xad}} expectAppendStreamFrames(ackhandler.Frame{Frame: f}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p).ToNot(BeNil()) Expect(p.packets).To(HaveLen(1)) @@ -266,7 +266,7 @@ var _ = Describe("Packet packer", func() { framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, _ protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { return frames, 0 }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(p).ToNot(BeNil()) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) @@ -535,7 +535,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetHandshakeSealer().Return(sealer, nil) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) packer.retransmissionQueue.AddHandshake(&wire.PingFrame{}) - packet, err := packer.PackCoalescedPacket() + packet, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(packet).ToNot(BeNil()) Expect(packet.packets).To(HaveLen(1)) @@ -784,7 +784,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) sealingManager.EXPECT().GetHandshakeSealer().Return(nil, handshake.ErrKeysNotYetAvailable) sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) checkLength(p.buffer.Data) }) @@ -805,7 +805,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(packer.version)).To(Equal(size)) return f }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].frames).To(HaveLen(1)) @@ -832,7 +832,7 @@ var _ = Describe("Packet packer", func() { handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(size protocol.ByteCount) *wire.CryptoFrame { return &wire.CryptoFrame{Offset: 0x1337, Data: []byte("handshake")} }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(2)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -868,7 +868,7 @@ var _ = Describe("Packet packer", func() { }) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Data).To(HaveLen(protocol.MinInitialPacketSize)) Expect(p.packets).To(HaveLen(2)) @@ -903,7 +903,7 @@ var _ = Describe("Packet packer", func() { }) expectAppendControlFrames() expectAppendStreamFrames(ackhandler.Frame{Frame: &wire.StreamFrame{Data: []byte("foobar")}}) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(2)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionHandshake)) @@ -935,7 +935,7 @@ var _ = Describe("Packet packer", func() { Expect(f.Length(packer.version)).To(Equal(s)) return f }) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -943,6 +943,62 @@ var _ = Describe("Packet packer", func() { checkLength(p.buffer.Data) }) + It("doesn't pack a coalesced packet if there's not enough space", func() { + p, err := packer.PackCoalescedPacket(protocol.MinCoalescedPacketSize - 1) + Expect(err).ToNot(HaveOccurred()) + Expect(p).To(BeNil()) + }) + + It("packs a small packet", func() { + const size = protocol.MinCoalescedPacketSize + 10 + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x24)) + sealingManager.EXPECT().GetInitialSealer().Return(getSealer(), nil) + // don't EXPECT any calls to GetHandshakeSealer and Get1RTTSealer + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) + initialStream.EXPECT().HasData().Return(true).Times(2) + initialStream.EXPECT().PopCryptoFrame(gomock.Any()).DoAndReturn(func(s protocol.ByteCount) *wire.CryptoFrame { + f := &wire.CryptoFrame{Offset: 0x1337} + f.Data = bytes.Repeat([]byte{'f'}, int(s-f.Length(packer.version)-1)) + Expect(f.Length(packer.version)).To(Equal(s)) + return f + }) + p, err := packer.PackCoalescedPacket(size) + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(len(p.buffer.Data)).To(Equal(size)) + }) + + It("packs a small packet, that includes a 1-RTT packet", func() { + const size = 2 * protocol.MinCoalescedPacketSize + pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x24)) + pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x24), protocol.PacketNumberLen2) + pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x24)) + sealingManager.EXPECT().GetInitialSealer().Return(nil, handshake.ErrKeysDropped) + sealingManager.EXPECT().GetHandshakeSealer().Return(getSealer(), nil) + oneRTTSealer := getSealer() + sealingManager.EXPECT().Get1RTTSealer().Return(oneRTTSealer, nil) + ackFramer.EXPECT().GetAckFrame(protocol.EncryptionHandshake) + handshakeStream.EXPECT().HasData().Return(true).Times(2) + handshakeStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(&wire.CryptoFrame{ + Offset: 0x1337, + Data: []byte("foobar"), + }) + expectAppendControlFrames() + var appDataSize protocol.ByteCount + framer.EXPECT().AppendStreamFrames(gomock.Any(), gomock.Any()).DoAndReturn(func(frames []ackhandler.Frame, maxSize protocol.ByteCount) ([]ackhandler.Frame, protocol.ByteCount) { + appDataSize = maxSize + f := &wire.StreamFrame{Data: []byte("foobar")} + return append(frames, ackhandler.Frame{Frame: f}), f.Length(packer.version) + }) + p, err := packer.PackCoalescedPacket(size) + Expect(err).ToNot(HaveOccurred()) + Expect(p).ToNot(BeNil()) + Expect(p.packets).To(HaveLen(2)) + Expect(appDataSize).To(Equal(size - p.packets[0].length - p.packets[1].header.GetLength(packer.version) - protocol.ByteCount(oneRTTSealer.Overhead()))) + }) + It("adds retransmissions", func() { f := &wire.CryptoFrame{Data: []byte("Initial")} retransmissionQueue.AddInitial(f) @@ -954,7 +1010,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) initialStream.EXPECT().HasData() - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) @@ -972,7 +1028,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionInitial).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) @@ -984,7 +1040,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) initialStream.EXPECT().HasData() ackFramer.EXPECT().GetAckFrame(protocol.EncryptionInitial) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p).To(BeNil()) }) @@ -1000,7 +1056,7 @@ var _ = Describe("Packet packer", func() { sealingManager.EXPECT().Get1RTTSealer().Return(nil, handshake.ErrKeysNotYetAvailable) pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42)) - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) @@ -1020,7 +1076,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().HasData().Return(true).Times(2) initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.buffer.Len()).To(BeEquivalentTo(protocol.MinInitialPacketSize)) Expect(p.packets).To(HaveLen(1)) @@ -1044,7 +1100,7 @@ var _ = Describe("Packet packer", func() { initialStream.EXPECT().PopCryptoFrame(gomock.Any()).Return(f) packer.version = protocol.VersionTLS packer.perspective = protocol.PerspectiveClient - p, err := packer.PackCoalescedPacket() + p, err := packer.PackCoalescedPacket(protocol.MaxByteCount) Expect(err).ToNot(HaveOccurred()) Expect(p.packets).To(HaveLen(1)) Expect(p.packets[0].ack).To(Equal(ack)) diff --git a/session.go b/session.go index 040bc12a6..9627dcd19 100644 --- a/session.go +++ b/session.go @@ -1427,7 +1427,7 @@ func (s *session) sendPacket() (bool, error) { if !s.handshakeConfirmed { now := time.Now() - packet, err := s.packer.PackCoalescedPacket() + packet, err := s.packer.PackCoalescedPacket(protocol.MaxByteCount) if err != nil || packet == nil { return false, err } diff --git a/session_test.go b/session_test.go index 86a5afbbc..52b88659e 100644 --- a/session_test.go +++ b/session_test.go @@ -1369,7 +1369,7 @@ var _ = Describe("Session", func() { sess.sentPacketHandler = sph buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - packer.EXPECT().PackCoalescedPacket().Return(&coalescedPacket{ + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Return(&coalescedPacket{ buffer: buffer, packets: []*packetContents{ { @@ -1394,7 +1394,7 @@ var _ = Describe("Session", func() { }, }, }, nil) - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1445,7 +1445,7 @@ var _ = Describe("Session", func() { }) It("cancels the HandshakeComplete context and informs the SentPacketHandler when the handshake completes", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() finishHandshake := make(chan struct{}) sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph @@ -1482,7 +1482,7 @@ var _ = Describe("Session", func() { It("sends a session ticket when the handshake completes", func() { const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2 - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() finishHandshake := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) go func() { @@ -1525,7 +1525,7 @@ var _ = Describe("Session", func() { }) It("doesn't cancel the HandshakeComplete context when the handshake fails", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() streamManager.EXPECT().CloseWithError(gomock.Any()) expectReplaceWithClosed() packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) @@ -1547,7 +1547,7 @@ var _ = Describe("Session", func() { It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { done := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackCoalescedPacket().DoAndReturn(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -1559,7 +1559,7 @@ var _ = Describe("Session", func() { buffer: getPacketBuffer(), }, nil }) - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -1630,7 +1630,7 @@ var _ = Describe("Session", func() { } streamManager.EXPECT().UpdateLimits(params) packer.EXPECT().HandleTransportParameters(params) - packer.EXPECT().PackCoalescedPacket().MaxTimes(3) + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).MaxTimes(3) Expect(sess.earlySessionReady()).ToNot(BeClosed()) sessionRunner.EXPECT().GetStatelessResetToken(gomock.Any()).Times(2) sessionRunner.EXPECT().Add(gomock.Any(), sess).Times(2) @@ -1677,7 +1677,7 @@ var _ = Describe("Session", func() { setRemoteIdleTimeout(5 * time.Second) sess.lastPacketReceivedTime = time.Now().Add(-5 * time.Second / 2) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Do(func(protocol.ByteCount) (*packedPacket, error) { close(sent) return nil, nil }) @@ -1690,7 +1690,7 @@ var _ = Describe("Session", func() { setRemoteIdleTimeout(time.Hour) sess.lastPacketReceivedTime = time.Now().Add(-protocol.MaxKeepAliveInterval).Add(-time.Millisecond) sent := make(chan struct{}) - packer.EXPECT().PackCoalescedPacket().Do(func() (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Do(func(protocol.ByteCount) (*packedPacket, error) { close(sent) return nil, nil }) @@ -1794,7 +1794,7 @@ var _ = Describe("Session", func() { }) It("closes the session due to the idle timeout after handshake", func() { - packer.EXPECT().PackCoalescedPacket().AnyTimes() + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() gomock.InOrder( sessionRunner.EXPECT().Retire(clientDestConnID), sessionRunner.EXPECT().Remove(gomock.Any()), @@ -2180,7 +2180,7 @@ var _ = Describe("Client Session", func() { }, } packer.EXPECT().HandleTransportParameters(gomock.Any()) - packer.EXPECT().PackCoalescedPacket().MaxTimes(1) + packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).MaxTimes(1) qlogger.EXPECT().ReceivedTransportParameters(params) sess.processTransportParameters(params) // make sure the connection ID is not retired From 60a918a108b0ab74b197326698c01896412715e2 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 23 Apr 2020 13:34:31 +0700 Subject: [PATCH 2/4] limit available window to 3x of received bytes before address validation --- internal/ackhandler/interfaces.go | 3 + .../mock_sent_packet_tracker_test.go | 12 ++++ .../ackhandler/received_packet_handler.go | 1 + .../received_packet_handler_test.go | 12 ++++ internal/ackhandler/sent_packet_handler.go | 38 +++++++++- .../ackhandler/sent_packet_handler_test.go | 69 +++++++++++++++++-- .../mocks/ackhandler/sent_packet_handler.go | 26 +++++++ session.go | 3 +- session_test.go | 35 ++++++++-- 9 files changed, 186 insertions(+), 13 deletions(-) diff --git a/internal/ackhandler/interfaces.go b/internal/ackhandler/interfaces.go index 9fcead66e..207b95a72 100644 --- a/internal/ackhandler/interfaces.go +++ b/internal/ackhandler/interfaces.go @@ -25,12 +25,14 @@ type SentPacketHandler interface { // SentPacket may modify the packet SentPacket(packet *Packet) ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) error + ReceivedBytes(protocol.ByteCount) DropPackets(protocol.EncryptionLevel) ResetForRetry() error SetHandshakeComplete() // The SendMode determines if and what kind of packets can be sent. SendMode() SendMode + AmplificationWindow() protocol.ByteCount // TimeUntilSend is the time when the next packet should be sent. // It is used for pacing packets. TimeUntilSend() time.Time @@ -56,6 +58,7 @@ type SentPacketHandler interface { type sentPacketTracker interface { GetLowestPacketNotConfirmedAcked() protocol.PacketNumber + ReceivedPacket(protocol.EncryptionLevel) } // ReceivedPacketHandler handles ACKs needed to send for incoming packets diff --git a/internal/ackhandler/mock_sent_packet_tracker_test.go b/internal/ackhandler/mock_sent_packet_tracker_test.go index 9ababef4a..9c3e8b8c5 100644 --- a/internal/ackhandler/mock_sent_packet_tracker_test.go +++ b/internal/ackhandler/mock_sent_packet_tracker_test.go @@ -47,3 +47,15 @@ func (mr *MockSentPacketTrackerMockRecorder) GetLowestPacketNotConfirmedAcked() mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLowestPacketNotConfirmedAcked", reflect.TypeOf((*MockSentPacketTracker)(nil).GetLowestPacketNotConfirmedAcked)) } + +// ReceivedPacket mocks base method +func (m *MockSentPacketTracker) ReceivedPacket(arg0 protocol.EncryptionLevel) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedPacket", arg0) +} + +// ReceivedPacket indicates an expected call of ReceivedPacket +func (mr *MockSentPacketTrackerMockRecorder) ReceivedPacket(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockSentPacketTracker)(nil).ReceivedPacket), arg0) +} diff --git a/internal/ackhandler/received_packet_handler.go b/internal/ackhandler/received_packet_handler.go index 367d53553..b4e5f0b98 100644 --- a/internal/ackhandler/received_packet_handler.go +++ b/internal/ackhandler/received_packet_handler.go @@ -64,6 +64,7 @@ func (h *receivedPacketHandler) ReceivedPacket( rcvTime time.Time, shouldInstigateAck bool, ) error { + h.sentPackets.ReceivedPacket(encLevel) switch encLevel { case protocol.EncryptionInitial: h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck) diff --git a/internal/ackhandler/received_packet_handler_test.go b/internal/ackhandler/received_packet_handler_test.go index 985c6ce62..0018c68ab 100644 --- a/internal/ackhandler/received_packet_handler_test.go +++ b/internal/ackhandler/received_packet_handler_test.go @@ -3,6 +3,8 @@ package ackhandler import ( "time" + "github.com/golang/mock/gomock" + "github.com/lucas-clemente/quic-go/internal/congestion" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" @@ -29,6 +31,9 @@ var _ = Describe("Received Packet Handler", func() { It("generates ACKs for different packet number spaces", func() { sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionInitial).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.EncryptionHandshake).Times(2) + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT).Times(2) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(5, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -54,6 +59,8 @@ var _ = Describe("Received Packet Handler", func() { It("uses the same packet number space for 0-RTT and 1-RTT packets", func() { sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption0RTT) + sentPackets.EXPECT().ReceivedPacket(protocol.Encryption1RTT) sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(3, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) @@ -64,6 +71,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("rejects 0-RTT packets with higher packet numbers than 1-RTT packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) @@ -72,6 +80,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("allows reordered 0-RTT packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(3) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now() Expect(handler.ReceivedPacket(10, protocol.Encryption0RTT, sendTime, true)).To(Succeed()) @@ -80,6 +89,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("drops Initial packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(2, protocol.EncryptionInitial, sendTime, true)).To(Succeed()) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) @@ -90,6 +100,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("drops Handshake packets", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).Times(2) sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().AnyTimes() sendTime := time.Now().Add(-time.Second) Expect(handler.ReceivedPacket(1, protocol.EncryptionHandshake, sendTime, true)).To(Succeed()) @@ -105,6 +116,7 @@ var _ = Describe("Received Packet Handler", func() { }) It("drops old ACK ranges", func() { + sentPackets.EXPECT().ReceivedPacket(gomock.Any()).AnyTimes() sendTime := time.Now() sentPackets.EXPECT().GetLowestPacketNotConfirmedAcked().Times(2) Expect(handler.ReceivedPacket(1, protocol.Encryption1RTT, sendTime, true)).To(Succeed()) diff --git a/internal/ackhandler/sent_packet_handler.go b/internal/ackhandler/sent_packet_handler.go index b3a3c40bc..007c35ce5 100644 --- a/internal/ackhandler/sent_packet_handler.go +++ b/internal/ackhandler/sent_packet_handler.go @@ -20,6 +20,8 @@ const ( timeThreshold = 9.0 / 8 // Maximum reordering in packets before packet threshold loss detection considers a packet lost. packetThreshold = 3 + // Before validating the client's address, the server won't send more than 3x bytes than it received. + amplificationFactor = 3 ) type packetNumberSpace struct { @@ -49,8 +51,16 @@ type sentPacketHandler struct { handshakePackets *packetNumberSpace appDataPackets *packetNumberSpace + // Do we know that the peer completed address validation yet? + // Always true for the server. peerCompletedAddressValidation bool - handshakeComplete bool + bytesReceived protocol.ByteCount + bytesSent protocol.ByteCount + // Have we validated the peer's address yet? + // Always true for the client. + peerAddressValidated bool + + handshakeComplete bool // lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived // example: we send an ACK for packets 90-100 with packet number 20 @@ -99,6 +109,7 @@ func newSentPacketHandler( return &sentPacketHandler{ peerCompletedAddressValidation: pers == protocol.PerspectiveServer, + peerAddressValidated: pers == protocol.PerspectiveClient, initialPackets: newPacketNumberSpace(initialPacketNumber), handshakePackets: newPacketNumberSpace(0), appDataPackets: newPacketNumberSpace(0), @@ -168,6 +179,16 @@ func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) { h.ptoMode = SendNone } +func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) { + h.bytesReceived += n +} + +func (h *sentPacketHandler) ReceivedPacket(encLevel protocol.EncryptionLevel) { + if h.perspective == protocol.PerspectiveServer && encLevel == protocol.EncryptionHandshake { + h.peerAddressValidated = true + } +} + func (h *sentPacketHandler) packetsInFlight() int { packetsInFlight := h.appDataPackets.history.Len() if h.handshakePackets != nil { @@ -180,6 +201,7 @@ func (h *sentPacketHandler) packetsInFlight() int { } func (h *sentPacketHandler) SentPacket(packet *Packet) { + h.bytesSent += packet.Length // For the client, drop the Initial packet number space when the first Handshake packet is sent. if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil { h.dropPackets(protocol.EncryptionInitial) @@ -638,6 +660,10 @@ func (h *sentPacketHandler) SendMode() SendMode { numTrackedPackets += h.handshakePackets.history.Len() } + if h.AmplificationWindow() == 0 { + h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent) + return SendNone + } // Don't send any packets if we're keeping track of the maximum number of packets. // Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets, // we will stop sending out new data when reaching MaxOutstandingSentPackets, @@ -683,6 +709,16 @@ func (h *sentPacketHandler) ShouldSendNumPackets() int { return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay))) } +func (h *sentPacketHandler) AmplificationWindow() protocol.ByteCount { + if h.peerAddressValidated { + return protocol.MaxByteCount + } + if h.bytesSent >= amplificationFactor*h.bytesReceived { + return 0 + } + return amplificationFactor*h.bytesReceived - h.bytesSent +} + func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool { pnSpace := h.getPacketNumberSpace(encLevel) p := pnSpace.history.FirstOutstanding() diff --git a/internal/ackhandler/sent_packet_handler_test.go b/internal/ackhandler/sent_packet_handler_test.go index e10bbb227..fb6df3181 100644 --- a/internal/ackhandler/sent_packet_handler_test.go +++ b/internal/ackhandler/sent_packet_handler_test.go @@ -495,13 +495,51 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.ReceivedAck(ack, protocol.Encryption1RTT, time.Now())).To(Succeed()) }) - It("passes the bytes in flight to CanSend", func() { - handler.bytesInFlight = 42 - cong.EXPECT().CanSend(protocol.ByteCount(42)) + It("passes the bytes in flight to the congestion controller", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) + cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(42), gomock.Any(), protocol.ByteCount(42), true) + cong.EXPECT().TimeUntilSend(gomock.Any()) + handler.SentPacket(&Packet{ + Length: 42, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + cong.EXPECT().CanSend(protocol.ByteCount(42)).Return(true) handler.SendMode() }) + It("returns SendNone if limited by the 3x limit", func() { + handler.ReceivedBytes(100) + cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(300), gomock.Any(), protocol.ByteCount(300), true) + cong.EXPECT().TimeUntilSend(gomock.Any()) + handler.SentPacket(&Packet{ + Length: 300, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + cong.EXPECT().CanSend(protocol.ByteCount(300)).Return(true).AnyTimes() + Expect(handler.AmplificationWindow()).To(BeZero()) + Expect(handler.SendMode()).To(Equal(SendNone)) + }) + + It("limits the window to 3x the bytes received, to avoid amplification attacks", func() { + handler.ReceivedPacket(protocol.EncryptionInitial) // receiving an Initial packet doesn't validate the client's address + cong.EXPECT().OnPacketSent(gomock.Any(), protocol.ByteCount(50), gomock.Any(), protocol.ByteCount(50), true) + cong.EXPECT().TimeUntilSend(gomock.Any()) + handler.SentPacket(&Packet{ + Length: 50, + EncryptionLevel: protocol.EncryptionInitial, + Frames: []Frame{{Frame: &wire.PingFrame{}}}, + SendTime: time.Now(), + }) + handler.ReceivedBytes(100) + Expect(handler.AmplificationWindow()).To(Equal(protocol.ByteCount(3*100 - 50))) + }) + It("allows sending of ACKs when congestion limited", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) cong.EXPECT().CanSend(gomock.Any()).Return(true) Expect(handler.SendMode()).To(Equal(SendAny)) cong.EXPECT().CanSend(gomock.Any()).Return(false) @@ -509,6 +547,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows sending of ACKs when we're keeping track of MaxOutstandingSentPackets packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) cong.EXPECT().CanSend(gomock.Any()).Return(true).AnyTimes() cong.EXPECT().TimeUntilSend(gomock.Any()).AnyTimes() cong.EXPECT().OnPacketSent(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() @@ -521,6 +560,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows PTOs, even when congestion limited", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) // note that we don't EXPECT a call to GetCongestionWindow // that means retransmissions are sent without considering the congestion window handler.numProbesToSend = 1 @@ -561,6 +601,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("doesn't set an alarm if there are no outstanding packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 10})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 11})) ack := &wire.AckFrame{AckRanges: []wire.AckRange{{Smallest: 10, Largest: 11}}} @@ -569,6 +610,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("does nothing on OnAlarm if there are no outstanding packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) Expect(handler.SendMode()).To(Equal(SendAny)) }) @@ -602,6 +644,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("reset the PTO count when receiving an ACK", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) @@ -615,6 +658,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("resets the PTO mode and PTO count when a packet number space is dropped", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() handler.SentPacket(ackElicitingPacket(&Packet{ PacketNumber: 1, @@ -638,6 +682,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("allows two 1-RTT PTOs", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() var lostPackets []protocol.PacketNumber handler.SentPacket(ackElicitingPacket(&Packet{ @@ -657,6 +702,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("only counts ack-eliciting packets as probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) @@ -672,7 +718,8 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.SendMode()).ToNot(Equal(SendPTOAppData)) }) - It("gets two probe packets if RTO expires", func() { + It("gets two probe packets if PTO expires", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 2})) @@ -698,6 +745,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("gets two probe packets if PTO expires, for Handshake packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(initialPacket(&Packet{PacketNumber: 1})) handler.SentPacket(initialPacket(&Packet{PacketNumber: 2})) @@ -714,6 +762,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("doesn't send 1-RTT probe packets before the handshake completes", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1})) updateRTT(time.Hour) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) // TLP @@ -726,6 +775,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("resets the send mode when it receives an acknowledgement after queueing probe packets", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SetHandshakeComplete() handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 1, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) @@ -737,6 +787,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("handles ACKs for the original packet", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) handler.SentPacket(ackElicitingPacket(&Packet{PacketNumber: 5, SendTime: time.Now().Add(-time.Hour)})) handler.rttStats.UpdateRTT(time.Second, 0, time.Now()) Expect(handler.OnLossDetectionTimeout()).To(Succeed()) @@ -993,6 +1044,7 @@ var _ = Describe("SentPacketHandler", func() { }) It("cancels the PTO when dropping a packet number space", func() { + handler.ReceivedPacket(protocol.EncryptionHandshake) now := time.Now() handler.SentPacket(handshakePacket(&Packet{PacketNumber: 1, SendTime: now.Add(-time.Minute)})) handler.SentPacket(handshakePacket(&Packet{PacketNumber: 2, SendTime: now.Add(-time.Minute)})) @@ -1028,12 +1080,15 @@ var _ = Describe("SentPacketHandler", func() { }) }) - Context("resetting for retry", func() { + Context("for the client", func() { BeforeEach(func() { perspective = protocol.PerspectiveClient }) - It("queues outstanding packets for retransmission, cancels alarms and resets PTO count", func() { + It("considers the server's address validated right away", func() { + }) + + It("queues outstanding packets for retransmission, cancels alarms and resets PTO count when receiving a Retry", func() { handler.SentPacket(initialPacket(&Packet{PacketNumber: 42})) Expect(handler.GetLossDetectionTimeout()).ToNot(BeZero()) Expect(handler.bytesInFlight).ToNot(BeZero()) @@ -1047,7 +1102,7 @@ var _ = Describe("SentPacketHandler", func() { Expect(handler.ptoCount).To(BeZero()) }) - It("queues outstanding frames for retransmission and cancels alarms", func() { + It("queues outstanding frames for retransmission and cancels alarms when receiving a Retry", func() { var lostInitial, lost0RTT bool handler.SentPacket(&Packet{ PacketNumber: 13, diff --git a/internal/mocks/ackhandler/sent_packet_handler.go b/internal/mocks/ackhandler/sent_packet_handler.go index 07e577473..557576049 100644 --- a/internal/mocks/ackhandler/sent_packet_handler.go +++ b/internal/mocks/ackhandler/sent_packet_handler.go @@ -38,6 +38,20 @@ func (m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder { return m.recorder } +// AmplificationWindow mocks base method +func (m *MockSentPacketHandler) AmplificationWindow() protocol.ByteCount { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AmplificationWindow") + ret0, _ := ret[0].(protocol.ByteCount) + return ret0 +} + +// AmplificationWindow indicates an expected call of AmplificationWindow +func (mr *MockSentPacketHandlerMockRecorder) AmplificationWindow() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AmplificationWindow", reflect.TypeOf((*MockSentPacketHandler)(nil).AmplificationWindow)) +} + // DropPackets mocks base method func (m *MockSentPacketHandler) DropPackets(arg0 protocol.EncryptionLevel) { m.ctrl.T.Helper() @@ -149,6 +163,18 @@ func (mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2 interf return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2) } +// ReceivedBytes mocks base method +func (m *MockSentPacketHandler) ReceivedBytes(arg0 protocol.ByteCount) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReceivedBytes", arg0) +} + +// ReceivedBytes indicates an expected call of ReceivedBytes +func (mr *MockSentPacketHandlerMockRecorder) ReceivedBytes(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedBytes", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedBytes), arg0) +} + // ResetForRetry mocks base method func (m *MockSentPacketHandler) ResetForRetry() error { m.ctrl.T.Helper() diff --git a/session.go b/session.go index 9627dcd19..862d8cbf4 100644 --- a/session.go +++ b/session.go @@ -700,6 +700,7 @@ func (s *session) handlePacketImpl(rp *receivedPacket) bool { var processed bool data := rp.data p := rp + s.sentPacketHandler.ReceivedBytes(protocol.ByteCount(len(data))) for len(data) > 0 { if counter > 0 { p = p.Clone() @@ -1427,7 +1428,7 @@ func (s *session) sendPacket() (bool, error) { if !s.handshakeConfirmed { now := time.Now() - packet, err := s.packer.PackCoalescedPacket(protocol.MaxByteCount) + packet, err := s.packer.PackCoalescedPacket(s.sentPacketHandler.AmplificationWindow()) if err != nil || packet == nil { return false, err } diff --git a/session_test.go b/session_test.go index 52b88659e..8213af939 100644 --- a/session_test.go +++ b/session_test.go @@ -1033,6 +1033,13 @@ var _ = Describe("Session", func() { It("sends packets", func() { sess.handshakeConfirmed = true + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().Return(1000) + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph runSession() p := getPacket(1) packer.EXPECT().PackPacket().Return(p, nil) @@ -1069,6 +1076,13 @@ var _ = Describe("Session", func() { It("adds a BLOCKED frame when it is connection-level flow control blocked", func() { sess.handshakeConfirmed = true + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().ShouldSendNumPackets().Return(1000) + sph.EXPECT().SentPacket(gomock.Any()) + sess.sentPacketHandler = sph fc := mocks.NewMockConnectionFlowController(mockCtrl) fc.EXPECT().IsNewlyBlocked().Return(true, protocol.ByteCount(1337)) fc.EXPECT().IsNewlyBlocked() @@ -1366,10 +1380,12 @@ var _ = Describe("Session", func() { It("sends coalesced packets before the handshake is confirmed", func() { sess.handshakeConfirmed = false sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + const window protocol.ByteCount = 321 + sph.EXPECT().AmplificationWindow().Return(window).AnyTimes() sess.sentPacketHandler = sph buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).Return(&coalescedPacket{ + packer.EXPECT().PackCoalescedPacket(window).Return(&coalescedPacket{ buffer: buffer, packets: []*packetContents{ { @@ -1394,7 +1410,7 @@ var _ = Describe("Session", func() { }, }, }, nil) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() + packer.EXPECT().PackCoalescedPacket(window).AnyTimes() sph.EXPECT().GetLossDetectionTimeout().AnyTimes() sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() @@ -1545,9 +1561,17 @@ var _ = Describe("Session", func() { }) It("sends a HANDSHAKE_DONE frame when the handshake completes", func() { + sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) + sph.EXPECT().SendMode().Return(ackhandler.SendAny).AnyTimes() + sph.EXPECT().AmplificationWindow().Return(protocol.MaxByteCount) + sph.EXPECT().SetHandshakeComplete() + sph.EXPECT().GetLossDetectionTimeout().AnyTimes() + sph.EXPECT().TimeUntilSend().AnyTimes() + sph.EXPECT().ShouldSendNumPackets().AnyTimes().Return(10) + sess.sentPacketHandler = sph done := make(chan struct{}) sessionRunner.EXPECT().Retire(clientDestConnID) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { + packer.EXPECT().PackCoalescedPacket(gomock.Any()).DoAndReturn(func(protocol.ByteCount) (*packedPacket, error) { frames, _ := sess.framer.AppendControlFrames(nil, protocol.MaxByteCount) Expect(frames).ToNot(BeEmpty()) Expect(frames[0].Frame).To(BeEquivalentTo(&wire.HandshakeDoneFrame{})) @@ -1559,7 +1583,7 @@ var _ = Describe("Session", func() { buffer: getPacketBuffer(), }, nil }) - packer.EXPECT().PackCoalescedPacket(protocol.MaxByteCount).AnyTimes() + packer.EXPECT().PackCoalescedPacket(gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() cryptoSetup.EXPECT().RunHandshake() @@ -1659,6 +1683,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { sess.config.MaxIdleTimeout = 30 * time.Second sess.config.KeepAlive = true + sess.receivedPacketHandler.ReceivedPacket(0, protocol.EncryptionHandshake, time.Now(), true) }) AfterEach(func() { @@ -2098,6 +2123,7 @@ var _ = Describe("Client Session", func() { It("handles Retry packets", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()) sph.EXPECT().ResetForRetry() cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) packer.EXPECT().SetToken([]byte("foobar")) @@ -2333,6 +2359,7 @@ var _ = Describe("Client Session", func() { It("ignores Initial packets which use original source id, after accepting a Retry", func() { sph := mockackhandler.NewMockSentPacketHandler(mockCtrl) sess.sentPacketHandler = sph + sph.EXPECT().ReceivedBytes(gomock.Any()).Times(2) sph.EXPECT().ResetForRetry() newSrcConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} cryptoSetup.EXPECT().ChangeConnectionID(newSrcConnID) From e4f02ff68cc51f3ccffdb10982e822622a9a7cd4 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 11 May 2020 14:19:19 +0700 Subject: [PATCH 3/4] generate a new CA and cert chain for every run of the integration tests --- integrationtests/self/self_suite_test.go | 102 +++++++++++++++++++---- 1 file changed, 85 insertions(+), 17 deletions(-) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index 4f03170fb..e5e71bac5 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -3,19 +3,23 @@ package self_test import ( "bufio" "bytes" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "flag" "fmt" "io" "log" - "math/rand" + "math/big" + mrand "math/rand" "os" "sync" "testing" + "time" "github.com/lucas-clemente/quic-go" - - "github.com/lucas-clemente/quic-go/internal/testdata" "github.com/lucas-clemente/quic-go/internal/utils" . "github.com/onsi/ginkgo" @@ -24,19 +28,6 @@ import ( const alpn = "quic-go integration tests" -func getTLSConfig() *tls.Config { - conf := testdata.GetTLSConfig() - conf.NextProtos = []string{alpn} - return conf -} - -func getTLSClientConfig() *tls.Config { - return &tls.Config{ - RootCAs: testdata.GetRootCA(), - NextProtos: []string{alpn}, - } -} - const ( dataLen = 500 * 1024 // 500 KB dataLenLong = 50 * 1024 * 1024 // 50 MB @@ -93,6 +84,11 @@ var ( logBufOnce sync.Once logBuf *syncedBuffer enableQlog bool + + caPrivateKey *rsa.PrivateKey + ca *x509.Certificate + leafPrivateKey *rsa.PrivateKey + leafCert *x509.Certificate ) // read the logfile command line flag @@ -100,6 +96,78 @@ var ( func init() { flag.StringVar(&logFileName, "logfile", "", "log file") flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") + + if err := generateCA(); err != nil { + panic(err) + } + if err := generateCertChain(); err != nil { + panic(err) + } +} + +func generateCA() error { + caCert := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + var err error + caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return err + } + ca, err = x509.ParseCertificate(caBytes) + return err +} + +func generateCertChain() error { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(1), + DNSNames: []string{"localhost"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + var err error + leafPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return err + } + certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &leafPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return err + } + leafCert, err = x509.ParseCertificate(certBytes) + return err +} + +func getTLSConfig() *tls.Config { + return &tls.Config{ + Certificates: []tls.Certificate{tls.Certificate{ + Certificate: [][]byte{leafCert.Raw}, + PrivateKey: leafPrivateKey, + }}, + NextProtos: []string{alpn}, + } +} + +func getTLSClientConfig() *tls.Config { + root := x509.NewCertPool() + root.AddCert(ca) + return &tls.Config{ + RootCAs: root, + NextProtos: []string{alpn}, + } } func getQuicConfigForClient(conf *quic.Config) *quic.Config { @@ -163,5 +231,5 @@ func TestSelf(t *testing.T) { } var _ = BeforeSuite(func() { - rand.Seed(GinkgoRandomSeed()) + mrand.Seed(GinkgoRandomSeed()) }) From e33f7d0fb95d488b6ccb8d675baed840eb34e831 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 11 May 2020 15:58:40 +0700 Subject: [PATCH 4/4] add integration tests using a very long certificate chain This will trigger the amplification protection. --- integrationtests/self/handshake_drop_test.go | 91 ++++++------ integrationtests/self/handshake_test.go | 46 ++++--- integrationtests/self/self_suite_test.go | 138 ++++++++++++++----- 3 files changed, 185 insertions(+), 90 deletions(-) diff --git a/integrationtests/self/handshake_drop_test.go b/integrationtests/self/handshake_drop_test.go index 92e9951ef..df062f93d 100644 --- a/integrationtests/self/handshake_drop_test.go +++ b/integrationtests/self/handshake_drop_test.go @@ -2,6 +2,7 @@ package self_test import ( "context" + "crypto/tls" "fmt" mrand "math/rand" "net" @@ -32,7 +33,7 @@ var _ = Describe("Handshake drop tests", func() { const timeout = 10 * time.Minute - startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, version protocol.VersionNumber) { + startListenerAndProxy := func(dropCallback quicproxy.DropCallback, doRetry bool, longCertChain bool, version protocol.VersionNumber) { conf := getQuicConfigForServer(&quic.Config{ MaxIdleTimeout: timeout, HandshakeTimeout: timeout, @@ -41,8 +42,14 @@ var _ = Describe("Handshake drop tests", func() { if !doRetry { conf.AcceptToken = func(net.Addr, *quic.Token) bool { return true } } + var tlsConf *tls.Config + if longCertChain { + tlsConf = getTLSConfigWithLongCertChain() + } else { + tlsConf = getTLSConfig() + } var err error - ln, err = quic.ListenAddr("localhost:0", getTLSConfig(), conf) + ln, err = quic.ListenAddr("localhost:0", tlsConf, conf) Expect(err).ToNot(HaveOccurred()) serverPort := ln.Addr().(*net.UDPAddr).Port proxy, err = quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ @@ -184,46 +191,52 @@ var _ = Describe("Handshake drop tests", func() { } Context(desc, func() { - for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { - app := a + for _, lcc := range []bool{false, true} { + longCertChain := lcc - Context(app.name, func() { - It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - var p int32 - switch d { - case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) - case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) - } - return p == 1 && d.Is(direction) - }, doRetry, version) - app.run(version) - }) + Context(fmt.Sprintf("using a long certificate chain: %t", longCertChain), func() { + for _, a := range []*applicationProtocol{clientSpeaksFirst, serverSpeaksFirst, nobodySpeaks} { + app := a - It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { - var incoming, outgoing int32 - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - var p int32 - switch d { - case quicproxy.DirectionIncoming: - p = atomic.AddInt32(&incoming, 1) - case quicproxy.DirectionOutgoing: - p = atomic.AddInt32(&outgoing, 1) - } - return p == 2 && d.Is(direction) - }, doRetry, version) - app.run(version) - }) + Context(app.name, func() { + It(fmt.Sprintf("establishes a connection when the first packet is lost in %s direction", direction), func() { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } + return p == 1 && d.Is(direction) + }, doRetry, longCertChain, version) + app.run(version) + }) - It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { - startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { - return d.Is(direction) && stochasticDropper(3) - }, doRetry, version) - app.run(version) - }) + It(fmt.Sprintf("establishes a connection when the second packet is lost in %s direction", direction), func() { + var incoming, outgoing int32 + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + var p int32 + switch d { + case quicproxy.DirectionIncoming: + p = atomic.AddInt32(&incoming, 1) + case quicproxy.DirectionOutgoing: + p = atomic.AddInt32(&outgoing, 1) + } + return p == 2 && d.Is(direction) + }, doRetry, longCertChain, version) + app.run(version) + }) + + It(fmt.Sprintf("establishes a connection when 1/3 of the packets are lost in %s direction", direction), func() { + startListenerAndProxy(func(d quicproxy.Direction, _ []byte) bool { + return d.Is(direction) && stochasticDropper(3) + }, doRetry, longCertChain, version) + app.run(version) + }) + }) + } }) } }) diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index ff5c9455a..e86e25958 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -51,14 +51,12 @@ var _ = Describe("Handshake tests", func() { server quic.Listener serverConfig *quic.Config acceptStopped chan struct{} - tlsServerConf *tls.Config ) BeforeEach(func() { server = nil acceptStopped = make(chan struct{}) serverConfig = getQuicConfigForServer(nil) - tlsServerConf = getTLSConfig() }) AfterEach(func() { @@ -68,10 +66,10 @@ var _ = Describe("Handshake tests", func() { } }) - runServer := func() quic.Listener { + runServer := func(tlsConf *tls.Config) { var err error // start the server - server, err = quic.ListenAddr("localhost:0", tlsServerConf, serverConfig) + server, err = quic.ListenAddr("localhost:0", tlsConf, serverConfig) Expect(err).ToNot(HaveOccurred()) go func() { @@ -83,7 +81,6 @@ var _ = Describe("Handshake tests", func() { } } }() - return server } if !israce.Enabled { @@ -103,7 +100,7 @@ var _ = Describe("Handshake tests", func() { // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak serverConfig.Versions = []protocol.VersionNumber{7, 8, protocol.SupportedVersions[0], 9} - server := runServer() + runServer(getTLSConfig()) defer server.Close() sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -119,7 +116,7 @@ var _ = Describe("Handshake tests", func() { // the server doesn't support the highest supported version, which is the first one the client will try // but it supports a bunch of versions that the client doesn't speak serverConfig.Versions = supportedVersions - server := runServer() + runServer(getTLSConfig()) defer server.Close() sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), @@ -145,9 +142,11 @@ var _ = Describe("Handshake tests", func() { suiteID := id It(fmt.Sprintf("using %s", name), func() { - tlsServerConf.CipherSuites = []uint16{suiteID} - ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig) + tlsConf := getTLSConfig() + tlsConf.CipherSuites = []uint16{suiteID} + ln, err := quic.ListenAddr("localhost:0", tlsConf, serverConfig) Expect(err).ToNot(HaveOccurred()) + defer ln.Close() go func() { defer GinkgoRecover() @@ -177,7 +176,7 @@ var _ = Describe("Handshake tests", func() { } }) - Context("Certifiate validation", func() { + Context("Certificate validation", func() { for _, v := range protocol.SupportedVersions { version := v @@ -189,11 +188,8 @@ var _ = Describe("Handshake tests", func() { clientConfig = getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}) }) - JustBeforeEach(func() { - runServer() - }) - It("accepts the certificate", func() { + runServer(getTLSConfig()) _, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), @@ -202,7 +198,18 @@ var _ = Describe("Handshake tests", func() { Expect(err).ToNot(HaveOccurred()) }) + It("works with a long certificate chain", func() { + runServer(getTLSConfigWithLongCertChain()) + _, err := quic.DialAddr( + fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), + getTLSClientConfig(), + getQuicConfigForClient(&quic.Config{Versions: []protocol.VersionNumber{version}}), + ) + Expect(err).ToNot(HaveOccurred()) + }) + It("errors if the server name doesn't match", func() { + runServer(getTLSConfig()) _, err := quic.DialAddr( fmt.Sprintf("127.0.0.1:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), @@ -212,7 +219,10 @@ var _ = Describe("Handshake tests", func() { }) It("fails the handshake if the client fails to provide the requested client cert", func() { - tlsServerConf.ClientAuth = tls.RequireAndVerifyClientCert + tlsConf := getTLSConfig() + tlsConf.ClientAuth = tls.RequireAndVerifyClientCert + runServer(tlsConf) + sess, err := quic.DialAddr( fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port), getTLSClientConfig(), @@ -234,6 +244,7 @@ var _ = Describe("Handshake tests", func() { }) It("uses the ServerName in the tls.Config", func() { + runServer(getTLSConfig()) tlsConf := getTLSClientConfig() tlsConf.ServerName = "localhost" _, err := quic.DialAddr( @@ -350,7 +361,7 @@ var _ = Describe("Handshake tests", func() { Context("ALPN", func() { It("negotiates an application protocol", func() { - ln, err := quic.ListenAddr("localhost:0", tlsServerConf, serverConfig) + ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) done := make(chan struct{}) @@ -379,7 +390,7 @@ var _ = Describe("Handshake tests", func() { }) It("errors if application protocol negotiation fails", func() { - server := runServer() + runServer(getTLSConfig()) tlsConf := getTLSClientConfig() tlsConf.NextProtos = []string{"foobar"} @@ -391,7 +402,6 @@ var _ = Describe("Handshake tests", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR")) Expect(err.Error()).To(ContainSubstring("no application protocol")) - Expect(server.Close()).To(Succeed()) }) }) diff --git a/integrationtests/self/self_suite_test.go b/integrationtests/self/self_suite_test.go index e5e71bac5..cd0821476 100644 --- a/integrationtests/self/self_suite_test.go +++ b/integrationtests/self/self_suite_test.go @@ -85,10 +85,9 @@ var ( logBuf *syncedBuffer enableQlog bool - caPrivateKey *rsa.PrivateKey - ca *x509.Certificate - leafPrivateKey *rsa.PrivateKey - leafCert *x509.Certificate + tlsConfig *tls.Config + tlsConfigLongChain *tls.Config + tlsClientConfig *tls.Config ) // read the logfile command line flag @@ -97,16 +96,37 @@ func init() { flag.StringVar(&logFileName, "logfile", "", "log file") flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") - if err := generateCA(); err != nil { + ca, caPrivateKey, err := generateCA() + if err != nil { panic(err) } - if err := generateCertChain(); err != nil { + leafCert, leafPrivateKey, err := generateLeafCert(ca, caPrivateKey) + if err != nil { panic(err) } + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{tls.Certificate{ + Certificate: [][]byte{leafCert.Raw}, + PrivateKey: leafPrivateKey, + }}, + NextProtos: []string{alpn}, + } + tlsConfLongChain, err := generateTLSConfigWithLongCertChain(ca, caPrivateKey) + if err != nil { + panic(err) + } + tlsConfigLongChain = tlsConfLongChain + + root := x509.NewCertPool() + root.AddCert(ca) + tlsClientConfig = &tls.Config{ + RootCAs: root, + NextProtos: []string{alpn}, + } } -func generateCA() error { - caCert := &x509.Certificate{ +func generateCA() (*x509.Certificate, *rsa.PrivateKey, error) { + certTempl := &x509.Certificate{ SerialNumber: big.NewInt(2019), Subject: pkix.Name{}, NotBefore: time.Now(), @@ -116,21 +136,23 @@ func generateCA() error { KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } - var err error - caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) + caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, nil, err } - caBytes, err := x509.CreateCertificate(rand.Reader, caCert, caCert, &caPrivateKey.PublicKey, caPrivateKey) + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey) if err != nil { - return err + return nil, nil, err } - ca, err = x509.ParseCertificate(caBytes) - return err + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, nil, err + } + return ca, caPrivateKey, nil } -func generateCertChain() error { - cert := &x509.Certificate{ +func generateLeafCert(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*x509.Certificate, *rsa.PrivateKey, error) { + certTempl := &x509.Certificate{ SerialNumber: big.NewInt(1), DNSNames: []string{"localhost"}, NotBefore: time.Now(), @@ -138,36 +160,86 @@ func generateCertChain() error { ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, } - var err error - leafPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048) + privKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, nil, err } - certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &leafPrivateKey.PublicKey, caPrivateKey) + certBytes, err := x509.CreateCertificate(rand.Reader, certTempl, ca, &privKey.PublicKey, caPrivateKey) if err != nil { - return err + return nil, nil, err } - leafCert, err = x509.ParseCertificate(certBytes) - return err + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, nil, err + } + return cert, privKey, nil } -func getTLSConfig() *tls.Config { +// getTLSConfigWithLongCertChain generates a tls.Config that uses a long certificate chain. +// The Root CA used is the same as for the config returned from getTLSConfig(). +func generateTLSConfigWithLongCertChain(ca *x509.Certificate, caPrivateKey *rsa.PrivateKey) (*tls.Config, error) { + const chainLen = 7 + certTempl := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + lastCA := ca + lastCAPrivKey := caPrivateKey + privKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + certs := make([]*x509.Certificate, chainLen) + for i := 0; i < chainLen; i++ { + caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, lastCA, &privKey.PublicKey, lastCAPrivKey) + if err != nil { + return nil, err + } + ca, err := x509.ParseCertificate(caBytes) + if err != nil { + return nil, err + } + certs[i] = ca + lastCA = ca + lastCAPrivKey = privKey + } + leafCert, leafPrivateKey, err := generateLeafCert(lastCA, lastCAPrivKey) + if err != nil { + return nil, err + } + + rawCerts := make([][]byte, chainLen+1) + for i, cert := range certs { + rawCerts[chainLen-i] = cert.Raw + } + rawCerts[0] = leafCert.Raw + return &tls.Config{ Certificates: []tls.Certificate{tls.Certificate{ - Certificate: [][]byte{leafCert.Raw}, + Certificate: rawCerts, PrivateKey: leafPrivateKey, }}, NextProtos: []string{alpn}, - } + }, nil +} + +func getTLSConfig() *tls.Config { + return tlsConfig.Clone() +} + +func getTLSConfigWithLongCertChain() *tls.Config { + return tlsConfigLongChain.Clone() } func getTLSClientConfig() *tls.Config { - root := x509.NewCertPool() - root.AddCert(ca) - return &tls.Config{ - RootCAs: root, - NextProtos: []string{alpn}, - } + return tlsClientConfig.Clone() } func getQuicConfigForClient(conf *quic.Config) *quic.Config {