use a gomock SentPacketHandler in the session tests

This commit is contained in:
Marten Seemann
2017-11-25 10:39:46 -08:00
parent 975e45d2e1
commit 5c7fb54445
3 changed files with 244 additions and 122 deletions

View File

@@ -3,6 +3,7 @@ package mocks
//go:generate sh -c "./mockgen_internal.sh mockhandshake handshake/mint_tls.go github.com/lucas-clemente/quic-go/internal/handshake MintTLS"
//go:generate sh -c "./mockgen_internal.sh mocks tls_extension_handler.go github.com/lucas-clemente/quic-go/internal/handshake TLSExtensionHandler"
//go:generate sh -c "./mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController"
//go:generate sh -c "./mockgen_internal.sh mocks sent_packet_handler.go github.com/lucas-clemente/quic-go/ackhandler SentPacketHandler"
//go:generate sh -c "./mockgen_internal.sh mocks connection_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol ConnectionFlowController"
//go:generate sh -c "./mockgen_internal.sh mockcrypto crypto/aead.go github.com/lucas-clemente/quic-go/internal/crypto AEAD"
//go:generate sh -c "./mockgen_stream.sh mocks stream.go github.com/lucas-clemente/quic-go StreamI"

View File

@@ -0,0 +1,153 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/lucas-clemente/quic-go/ackhandler (interfaces: SentPacketHandler)
package mocks
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
ackhandler "github.com/lucas-clemente/quic-go/ackhandler"
protocol "github.com/lucas-clemente/quic-go/internal/protocol"
wire "github.com/lucas-clemente/quic-go/internal/wire"
)
// MockSentPacketHandler is a mock of SentPacketHandler interface
type MockSentPacketHandler struct {
ctrl *gomock.Controller
recorder *MockSentPacketHandlerMockRecorder
}
// MockSentPacketHandlerMockRecorder is the mock recorder for MockSentPacketHandler
type MockSentPacketHandlerMockRecorder struct {
mock *MockSentPacketHandler
}
// NewMockSentPacketHandler creates a new mock instance
func NewMockSentPacketHandler(ctrl *gomock.Controller) *MockSentPacketHandler {
mock := &MockSentPacketHandler{ctrl: ctrl}
mock.recorder = &MockSentPacketHandlerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (_m *MockSentPacketHandler) EXPECT() *MockSentPacketHandlerMockRecorder {
return _m.recorder
}
// DequeuePacketForRetransmission mocks base method
func (_m *MockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet {
ret := _m.ctrl.Call(_m, "DequeuePacketForRetransmission")
ret0, _ := ret[0].(*ackhandler.Packet)
return ret0
}
// DequeuePacketForRetransmission indicates an expected call of DequeuePacketForRetransmission
func (_mr *MockSentPacketHandlerMockRecorder) DequeuePacketForRetransmission() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "DequeuePacketForRetransmission", reflect.TypeOf((*MockSentPacketHandler)(nil).DequeuePacketForRetransmission))
}
// GetAlarmTimeout mocks base method
func (_m *MockSentPacketHandler) GetAlarmTimeout() time.Time {
ret := _m.ctrl.Call(_m, "GetAlarmTimeout")
ret0, _ := ret[0].(time.Time)
return ret0
}
// GetAlarmTimeout indicates an expected call of GetAlarmTimeout
func (_mr *MockSentPacketHandlerMockRecorder) GetAlarmTimeout() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetAlarmTimeout", reflect.TypeOf((*MockSentPacketHandler)(nil).GetAlarmTimeout))
}
// GetLeastUnacked mocks base method
func (_m *MockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber {
ret := _m.ctrl.Call(_m, "GetLeastUnacked")
ret0, _ := ret[0].(protocol.PacketNumber)
return ret0
}
// GetLeastUnacked indicates an expected call of GetLeastUnacked
func (_mr *MockSentPacketHandlerMockRecorder) GetLeastUnacked() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetLeastUnacked", reflect.TypeOf((*MockSentPacketHandler)(nil).GetLeastUnacked))
}
// GetStopWaitingFrame mocks base method
func (_m *MockSentPacketHandler) GetStopWaitingFrame(_param0 bool) *wire.StopWaitingFrame {
ret := _m.ctrl.Call(_m, "GetStopWaitingFrame", _param0)
ret0, _ := ret[0].(*wire.StopWaitingFrame)
return ret0
}
// GetStopWaitingFrame indicates an expected call of GetStopWaitingFrame
func (_mr *MockSentPacketHandlerMockRecorder) GetStopWaitingFrame(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "GetStopWaitingFrame", reflect.TypeOf((*MockSentPacketHandler)(nil).GetStopWaitingFrame), arg0)
}
// OnAlarm mocks base method
func (_m *MockSentPacketHandler) OnAlarm() {
_m.ctrl.Call(_m, "OnAlarm")
}
// OnAlarm indicates an expected call of OnAlarm
func (_mr *MockSentPacketHandlerMockRecorder) OnAlarm() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "OnAlarm", reflect.TypeOf((*MockSentPacketHandler)(nil).OnAlarm))
}
// ReceivedAck mocks base method
func (_m *MockSentPacketHandler) ReceivedAck(_param0 *wire.AckFrame, _param1 protocol.PacketNumber, _param2 protocol.EncryptionLevel, _param3 time.Time) error {
ret := _m.ctrl.Call(_m, "ReceivedAck", _param0, _param1, _param2, _param3)
ret0, _ := ret[0].(error)
return ret0
}
// ReceivedAck indicates an expected call of ReceivedAck
func (_mr *MockSentPacketHandlerMockRecorder) ReceivedAck(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ReceivedAck", reflect.TypeOf((*MockSentPacketHandler)(nil).ReceivedAck), arg0, arg1, arg2, arg3)
}
// SendingAllowed mocks base method
func (_m *MockSentPacketHandler) SendingAllowed() bool {
ret := _m.ctrl.Call(_m, "SendingAllowed")
ret0, _ := ret[0].(bool)
return ret0
}
// SendingAllowed indicates an expected call of SendingAllowed
func (_mr *MockSentPacketHandlerMockRecorder) SendingAllowed() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SendingAllowed", reflect.TypeOf((*MockSentPacketHandler)(nil).SendingAllowed))
}
// SentPacket mocks base method
func (_m *MockSentPacketHandler) SentPacket(_param0 *ackhandler.Packet) error {
ret := _m.ctrl.Call(_m, "SentPacket", _param0)
ret0, _ := ret[0].(error)
return ret0
}
// SentPacket indicates an expected call of SentPacket
func (_mr *MockSentPacketHandlerMockRecorder) SentPacket(arg0 interface{}) *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SentPacket", reflect.TypeOf((*MockSentPacketHandler)(nil).SentPacket), arg0)
}
// SetHandshakeComplete mocks base method
func (_m *MockSentPacketHandler) SetHandshakeComplete() {
_m.ctrl.Call(_m, "SetHandshakeComplete")
}
// SetHandshakeComplete indicates an expected call of SetHandshakeComplete
func (_mr *MockSentPacketHandlerMockRecorder) SetHandshakeComplete() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "SetHandshakeComplete", reflect.TypeOf((*MockSentPacketHandler)(nil).SetHandshakeComplete))
}
// ShouldSendRetransmittablePacket mocks base method
func (_m *MockSentPacketHandler) ShouldSendRetransmittablePacket() bool {
ret := _m.ctrl.Call(_m, "ShouldSendRetransmittablePacket")
ret0, _ := ret[0].(bool)
return ret0
}
// ShouldSendRetransmittablePacket indicates an expected call of ShouldSendRetransmittablePacket
func (_mr *MockSentPacketHandlerMockRecorder) ShouldSendRetransmittablePacket() *gomock.Call {
return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "ShouldSendRetransmittablePacket", reflect.TypeOf((*MockSentPacketHandler)(nil).ShouldSendRetransmittablePacket))
}

View File

@@ -70,53 +70,6 @@ func (m *mockUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte
}, nil
}
type mockSentPacketHandler struct {
retransmissionQueue []*ackhandler.Packet
sentPackets []*ackhandler.Packet
congestionLimited bool
requestedStopWaiting bool
shouldSendRetransmittablePacket bool
}
func (h *mockSentPacketHandler) SentPacket(packet *ackhandler.Packet) error {
h.sentPackets = append(h.sentPackets, packet)
return nil
}
func (h *mockSentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error {
return nil
}
func (h *mockSentPacketHandler) SetHandshakeComplete() {}
func (h *mockSentPacketHandler) GetLeastUnacked() protocol.PacketNumber { return 1 }
func (h *mockSentPacketHandler) GetAlarmTimeout() time.Time { panic("not implemented") }
func (h *mockSentPacketHandler) OnAlarm() { panic("not implemented") }
func (h *mockSentPacketHandler) SendingAllowed() bool { return !h.congestionLimited }
func (h *mockSentPacketHandler) ShouldSendRetransmittablePacket() bool {
b := h.shouldSendRetransmittablePacket
h.shouldSendRetransmittablePacket = false
return b
}
func (h *mockSentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
h.requestedStopWaiting = true
return &wire.StopWaitingFrame{LeastUnacked: 0x1337}
}
func (h *mockSentPacketHandler) DequeuePacketForRetransmission() *ackhandler.Packet {
if len(h.retransmissionQueue) > 0 {
packet := h.retransmissionQueue[0]
h.retransmissionQueue = h.retransmissionQueue[1:]
return packet
}
return nil
}
func newMockSentPacketHandler() ackhandler.SentPacketHandler {
return &mockSentPacketHandler{}
}
var _ ackhandler.SentPacketHandler = &mockSentPacketHandler{}
type mockReceivedPacketHandler struct {
nextAckFrame *wire.AckFrame
ackAlarm time.Time
@@ -341,6 +294,18 @@ var _ = Describe("Session", func() {
})
})
Context("handling ACK frames", func() {
It("informs the SentPacketHandler about ACKs", func() {
sph := mocks.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sess.lastRcvdPacketNumber = 42
f := &wire.AckFrame{LargestAcked: 3, LowestAcked: 2}
err := sess.handleAckFrame(f, protocol.EncryptionSecure)
Expect(err).ToNot(HaveOccurred())
sph.EXPECT().ReceivedAck(f, protocol.PacketNumber(42), protocol.EncryptionSecure, gomock.Any())
})
})
Context("handling RST_STREAM frames", func() {
It("closes the streams for writing", func() {
str, err := sess.GetOrOpenStream(5)
@@ -803,7 +768,12 @@ var _ = Describe("Session", func() {
})
It("sends ACK frames when congestion limited", func() {
sess.sentPacketHandler = &mockSentPacketHandler{congestionLimited: true}
sph := mocks.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLeastUnacked().AnyTimes()
sph.EXPECT().SendingAllowed().Return(false)
sph.EXPECT().GetStopWaitingFrame(false)
sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph
sess.packer.packetNumberGenerator.next = 0x1338
packetNumber := protocol.PacketNumber(0x035e)
sess.receivedPacketHandler.ReceivedPacket(packetNumber, true)
@@ -814,12 +784,18 @@ var _ = Describe("Session", func() {
})
It("sends a retransmittable packet when required by the SentPacketHandler", func() {
sess.sentPacketHandler = &mockSentPacketHandler{shouldSendRetransmittablePacket: true}
sess.packer.QueueControlFrame(&wire.AckFrame{LargestAcked: 1000})
sph := mocks.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLeastUnacked().AnyTimes()
sph.EXPECT().SendingAllowed().Return(true)
sph.EXPECT().SendingAllowed().Return(false)
sph.EXPECT().DequeuePacketForRetransmission()
sph.EXPECT().ShouldSendRetransmittablePacket().Return(true)
sph.EXPECT().SentPacket(gomock.Any())
sess.sentPacketHandler = sph
err := sess.sendPacket()
Expect(err).ToNot(HaveOccurred())
Expect(mconn.written).To(HaveLen(1))
Expect(sess.sentPacketHandler.(*mockSentPacketHandler).sentPackets[0].Frames).To(ContainElement(&wire.PingFrame{}))
})
It("sends public reset", func() {
@@ -830,36 +806,48 @@ var _ = Describe("Session", func() {
})
It("informs the SentPacketHandler about sent packets", func() {
sess.sentPacketHandler = newMockSentPacketHandler()
sess.packer.packetNumberGenerator.next = 0x1337 + 9
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
f := &wire.StreamFrame{
StreamID: 5,
Data: []byte("foobar"),
}
var sentPacket *ackhandler.Packet
sph := mocks.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLeastUnacked().AnyTimes()
sph.EXPECT().GetStopWaitingFrame(gomock.Any())
sph.EXPECT().SendingAllowed().Return(true)
sph.EXPECT().SendingAllowed().Return(false)
sph.EXPECT().DequeuePacketForRetransmission()
sph.EXPECT().ShouldSendRetransmittablePacket()
sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) {
sentPacket = p
})
sess.sentPacketHandler = sph
sess.packer.packetNumberGenerator.next = 0x1337 + 9
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
sess.streamFramer.AddFrameForRetransmission(f)
_, err := sess.GetOrOpenStream(5)
Expect(err).ToNot(HaveOccurred())
err = sess.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(mconn.written).To(HaveLen(1))
sentPackets := sess.sentPacketHandler.(*mockSentPacketHandler).sentPackets
Expect(sentPackets).To(HaveLen(1))
Expect(sentPackets[0].Frames).To(ContainElement(f))
Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
Expect(mconn.written).To(HaveLen(1))
Expect(sentPackets[0].Length).To(BeEquivalentTo(len(<-mconn.written)))
Expect(sentPacket.PacketNumber).To(Equal(protocol.PacketNumber(0x1337 + 9)))
Expect(sentPacket.Frames).To(ContainElement(f))
Expect(sentPacket.EncryptionLevel).To(Equal(protocol.EncryptionForwardSecure))
Expect(sentPacket.Length).To(BeEquivalentTo(len(<-mconn.written)))
})
})
Context("retransmissions", func() {
var sph *mockSentPacketHandler
var sph *mocks.MockSentPacketHandler
BeforeEach(func() {
// a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet
// a STOP_WAITING frame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet
sess.packer.packetNumberGenerator.next = 0x1337 + 10
sess.packer.hasSentPacket = true // make sure this is not the first packet the packer sends
sph = newMockSentPacketHandler().(*mockSentPacketHandler)
sph = mocks.NewMockSentPacketHandler(mockCtrl)
sph.EXPECT().GetLeastUnacked().AnyTimes()
sph.EXPECT().SendingAllowed().Return(true)
sph.EXPECT().ShouldSendRetransmittablePacket()
sess.sentPacketHandler = sph
sess.packer.cryptoSetup = &mockCryptoSetup{encLevelSeal: protocol.EncryptionForwardSecure}
})
@@ -867,45 +855,36 @@ var _ = Describe("Session", func() {
Context("for handshake packets", func() {
It("retransmits an unencrypted packet", func() {
sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")}
sph.retransmissionQueue = []*ackhandler.Packet{{
Frames: []wire.Frame{sf},
EncryptionLevel: protocol.EncryptionUnencrypted,
}}
var sentPacket *ackhandler.Packet
sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{LeastUnacked: 0x1337})
sph.EXPECT().DequeuePacketForRetransmission().Return(
&ackhandler.Packet{
Frames: []wire.Frame{sf},
EncryptionLevel: protocol.EncryptionUnencrypted,
})
sph.EXPECT().DequeuePacketForRetransmission()
sph.EXPECT().SentPacket(gomock.Any()).Do(func(p *ackhandler.Packet) {
sentPacket = p
})
err := sess.sendPacket()
Expect(err).ToNot(HaveOccurred())
Expect(mconn.written).To(HaveLen(1))
sentPackets := sph.sentPackets
Expect(sentPackets).To(HaveLen(1))
Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
Expect(sentPackets[0].Frames).To(HaveLen(2))
Expect(sentPackets[0].Frames[1]).To(Equal(sf))
swf := sentPackets[0].Frames[0].(*wire.StopWaitingFrame)
Expect(sentPacket.EncryptionLevel).To(Equal(protocol.EncryptionUnencrypted))
Expect(sentPacket.Frames).To(HaveLen(2))
Expect(sentPacket.Frames[1]).To(Equal(sf))
swf := sentPacket.Frames[0].(*wire.StopWaitingFrame)
Expect(swf.LeastUnacked).To(Equal(protocol.PacketNumber(0x1337)))
})
It("retransmit a packet encrypted with the initial encryption", func() {
sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")}
sph.retransmissionQueue = []*ackhandler.Packet{{
Frames: []wire.Frame{sf},
EncryptionLevel: protocol.EncryptionSecure,
}}
err := sess.sendPacket()
Expect(err).ToNot(HaveOccurred())
Expect(mconn.written).To(HaveLen(1))
sentPackets := sph.sentPackets
Expect(sentPackets).To(HaveLen(1))
Expect(sentPackets[0].EncryptionLevel).To(Equal(protocol.EncryptionSecure))
Expect(sentPackets[0].Frames).To(HaveLen(2))
Expect(sentPackets[0].Frames).To(ContainElement(sf))
})
It("doesn't retransmit handshake packets when the handshake is complete", func() {
sess.handshakeComplete = true
sf := &wire.StreamFrame{StreamID: 1, Data: []byte("foobar")}
sph.retransmissionQueue = []*ackhandler.Packet{{
Frames: []wire.Frame{sf},
EncryptionLevel: protocol.EncryptionSecure,
}}
sph.EXPECT().DequeuePacketForRetransmission().Return(
&ackhandler.Packet{
Frames: []wire.Frame{sf},
EncryptionLevel: protocol.EncryptionSecure,
})
sph.EXPECT().DequeuePacketForRetransmission()
err := sess.sendPacket()
Expect(err).ToNot(HaveOccurred())
Expect(mconn.written).To(BeEmpty())
@@ -918,17 +897,19 @@ var _ = Describe("Session", func() {
StreamID: 0x5,
Data: []byte("foobar1234567"),
}
p := ackhandler.Packet{
PacketNumber: 0x1337,
Frames: []wire.Frame{&f},
EncryptionLevel: protocol.EncryptionForwardSecure,
}
sph.retransmissionQueue = []*ackhandler.Packet{&p}
sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{})
sph.EXPECT().DequeuePacketForRetransmission().Return(
&ackhandler.Packet{
PacketNumber: 0x1337,
Frames: []wire.Frame{&f},
EncryptionLevel: protocol.EncryptionForwardSecure,
})
sph.EXPECT().DequeuePacketForRetransmission()
sph.EXPECT().SentPacket(gomock.Any())
sph.EXPECT().SendingAllowed()
err := sess.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(mconn.written).To(HaveLen(1))
Expect(sph.requestedStopWaiting).To(BeTrue())
Expect(mconn.written).To(Receive(ContainSubstring("foobar1234567")))
})
@@ -941,17 +922,22 @@ var _ = Describe("Session", func() {
StreamID: 0x7,
Data: []byte("loremipsum"),
}
p1 := ackhandler.Packet{
p1 := &ackhandler.Packet{
PacketNumber: 0x1337,
Frames: []wire.Frame{&f1},
EncryptionLevel: protocol.EncryptionForwardSecure,
}
p2 := ackhandler.Packet{
p2 := &ackhandler.Packet{
PacketNumber: 0x1338,
Frames: []wire.Frame{&f2},
EncryptionLevel: protocol.EncryptionForwardSecure,
}
sph.retransmissionQueue = []*ackhandler.Packet{&p1, &p2}
sph.EXPECT().DequeuePacketForRetransmission().Return(p1)
sph.EXPECT().DequeuePacketForRetransmission().Return(p2)
sph.EXPECT().DequeuePacketForRetransmission()
sph.EXPECT().GetStopWaitingFrame(true).Return(&wire.StopWaitingFrame{})
sph.EXPECT().SentPacket(gomock.Any())
sph.EXPECT().SendingAllowed()
err := sess.sendPacket()
Expect(err).NotTo(HaveOccurred())
@@ -960,24 +946,6 @@ var _ = Describe("Session", func() {
Expect(packet).To(ContainSubstring("foobar"))
Expect(packet).To(ContainSubstring("loremipsum"))
})
It("always attaches a StopWaiting to a packet that contains a retransmission", func() {
f := &wire.StreamFrame{
StreamID: 0x5,
Data: bytes.Repeat([]byte{'f'}, int(1.5*float32(protocol.MaxPacketSize))),
}
sess.streamFramer.AddFrameForRetransmission(f)
err := sess.sendPacket()
Expect(err).NotTo(HaveOccurred())
Expect(mconn.written).To(HaveLen(2))
sentPackets := sph.sentPackets
Expect(sentPackets).To(HaveLen(2))
_, ok := sentPackets[0].Frames[0].(*wire.StopWaitingFrame)
Expect(ok).To(BeTrue())
_, ok = sentPackets[1].Frames[0].(*wire.StopWaitingFrame)
Expect(ok).To(BeTrue())
})
})
})