forked from quic-go/quic-go
58
session.go
58
session.go
@@ -64,6 +64,7 @@ type Session struct {
|
||||
lastRcvdPacketNumber protocol.PacketNumber
|
||||
|
||||
rttStats congestion.RTTStats
|
||||
congestion congestion.SendAlgorithm
|
||||
}
|
||||
|
||||
// NewSession makes a new session
|
||||
@@ -97,6 +98,14 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
session.packer = &packetPacker{aead: cryptoSetup, connectionParametersManager: session.connectionParametersManager, connectionID: connectionID}
|
||||
session.unpacker = &packetUnpacker{aead: cryptoSetup}
|
||||
|
||||
session.congestion = congestion.NewCubicSender(
|
||||
congestion.DefaultClock{},
|
||||
&session.rttStats,
|
||||
false, /* don't use reno since chromium doesn't (why?) */
|
||||
protocol.InitialCongestionWindow,
|
||||
protocol.DefaultMaxCongestionWindow,
|
||||
)
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -175,13 +184,9 @@ func (s *Session) handlePacket(remoteAddr interface{}, publicHeader *PublicHeade
|
||||
utils.Debugf("\t<- &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: %d}", frame.StreamID, frame.FinBit, frame.Offset)
|
||||
err = s.handleStreamFrame(frame)
|
||||
case *frames.AckFrame:
|
||||
var duration time.Duration
|
||||
duration, _, _, err = s.sentPacketHandler.ReceivedAck(frame)
|
||||
s.rttStats.UpdateRTT(duration, frame.DelayTime, time.Now())
|
||||
utils.Debugf("\t<- %#v", frame)
|
||||
utils.Debugf("\tEstimated RTT: %dms", s.rttStats.SmoothedRTT()/time.Millisecond)
|
||||
// ToDo: send right error in ConnectionClose frame
|
||||
err = s.handleAckFrame(frame)
|
||||
case *frames.ConnectionCloseFrame:
|
||||
// ToDo: send right error in ConnectionClose frame
|
||||
utils.Debugf("\t<- %#v", frame)
|
||||
s.Close(nil, false)
|
||||
case *frames.StopWaitingFrame:
|
||||
@@ -269,6 +274,37 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) handleAckFrame(frame *frames.AckFrame) error {
|
||||
duration, acked, lost, err := s.sentPacketHandler.ReceivedAck(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: Don't always update RTT
|
||||
s.rttStats.UpdateRTT(duration, frame.DelayTime, time.Now())
|
||||
|
||||
cAcked := make(congestion.PacketVector, len(acked))
|
||||
for i, v := range acked {
|
||||
cAcked[i].Number = v.PacketNumber
|
||||
cAcked[i].Length = v.Length
|
||||
}
|
||||
cLost := make(congestion.PacketVector, len(lost))
|
||||
for i, v := range lost {
|
||||
cLost[i].Number = v.PacketNumber
|
||||
cLost[i].Length = v.Length
|
||||
}
|
||||
s.congestion.OnCongestionEvent(
|
||||
true, /* rtt updated */
|
||||
s.sentPacketHandler.BytesInFlight(),
|
||||
cAcked,
|
||||
cLost,
|
||||
)
|
||||
|
||||
utils.Debugf("\t<- %#v", frame)
|
||||
utils.Debugf("\tEstimated RTT: %dms", s.rttStats.SmoothedRTT()/time.Millisecond)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close the connection
|
||||
func (s *Session) Close(e error, sendConnectionClose bool) error {
|
||||
if s.closed {
|
||||
@@ -367,6 +403,14 @@ func (s *Session) sendPacket() error {
|
||||
return err
|
||||
}
|
||||
|
||||
s.congestion.OnPacketSent(
|
||||
time.Now(),
|
||||
s.sentPacketHandler.BytesInFlight(),
|
||||
packet.number,
|
||||
protocol.ByteCount(len(packet.raw)),
|
||||
true, /* TODO: is retransmittable */
|
||||
)
|
||||
|
||||
s.stopWaitingManager.SentStopWaitingWithPacket(packet.number)
|
||||
|
||||
utils.Debugf("-> Sending packet 0x%x (%d bytes)", packet.number, len(packet.raw))
|
||||
@@ -452,5 +496,5 @@ func (s *Session) scheduleSending() {
|
||||
}
|
||||
|
||||
func (s *Session) congestionAllowsSending() bool {
|
||||
return s.sentPacketHandler.BytesInFlight() < 100*protocol.DefaultTCPMSS
|
||||
return s.sentPacketHandler.BytesInFlight() <= s.congestion.GetCongestionWindow()
|
||||
}
|
||||
|
||||
114
session_test.go
114
session_test.go
@@ -10,6 +10,8 @@ import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lucas-clemente/quic-go/crypto"
|
||||
"github.com/lucas-clemente/quic-go/frames"
|
||||
"github.com/lucas-clemente/quic-go/handshake"
|
||||
@@ -29,6 +31,52 @@ func (m *mockConnection) write(p []byte) error {
|
||||
|
||||
func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {}
|
||||
|
||||
type mockCongestion struct {
|
||||
nCalls int
|
||||
argsOnPacketSent []interface{}
|
||||
argsOnCongestionEvent []interface{}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) TimeUntilSend(now time.Time, bytesInFlight protocol.ByteCount) time.Duration {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool) bool {
|
||||
m.nCalls++
|
||||
m.argsOnPacketSent = []interface{}{sentTime, bytesInFlight, packetNumber, bytes, isRetransmittable}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *mockCongestion) GetCongestionWindow() protocol.ByteCount {
|
||||
m.nCalls++
|
||||
return protocol.DefaultTCPMSS
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnCongestionEvent(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets congestion.PacketVector, lostPackets congestion.PacketVector) {
|
||||
m.nCalls++
|
||||
m.argsOnCongestionEvent = []interface{}{rttUpdated, bytesInFlight, ackedPackets, lostPackets}
|
||||
}
|
||||
|
||||
func (m *mockCongestion) SetNumEmulatedConnections(n int) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnRetransmissionTimeout(packetsRetransmitted bool) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) OnConnectionMigration() {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) RetransmissionDelay() time.Duration {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (m *mockCongestion) SetSlowStartLargeReduction(enabled bool) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// TODO: Reorganize
|
||||
var _ = Describe("Session", func() {
|
||||
var (
|
||||
@@ -357,4 +405,70 @@ var _ = Describe("Session", func() {
|
||||
Expect(conn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
close(done)
|
||||
}, 0.5)
|
||||
|
||||
Context("congestion", func() {
|
||||
var (
|
||||
cong *mockCongestion
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
signer, err := crypto.NewRSASigner(testdata.GetTLSConfig())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer)
|
||||
session = NewSession(conn, 0, 0, scfg, nil, func(*Session) {}).(*Session)
|
||||
|
||||
cong = &mockCongestion{}
|
||||
session.congestion = cong
|
||||
})
|
||||
|
||||
It("should call OnSent", func() {
|
||||
session.QueueStreamFrame(&frames.StreamFrame{})
|
||||
session.sendPacket()
|
||||
Expect(cong.nCalls).To(Equal(2)) // OnPacketSent + GetCongestionWindow
|
||||
Expect(cong.argsOnPacketSent[1]).To(Equal(protocol.ByteCount(35)))
|
||||
Expect(cong.argsOnPacketSent[2]).To(Equal(protocol.PacketNumber(1)))
|
||||
Expect(cong.argsOnPacketSent[3]).To(Equal(protocol.ByteCount(35)))
|
||||
Expect(cong.argsOnPacketSent[4]).To(BeTrue())
|
||||
})
|
||||
|
||||
It("should call OnCongestionEvent", func() {
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 1, Frames: []frames.Frame{}, Length: 1})
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 2, Frames: []frames.Frame{}, Length: 2})
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 3, Frames: []frames.Frame{}, Length: 3})
|
||||
err := session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 3,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(cong.nCalls).To(Equal(1))
|
||||
//(rttUpdated bool, bytesInFlight protocol.ByteCount, ackedPackets cong.PacketVector, lostPackets cong.PacketVector)
|
||||
Expect(cong.argsOnCongestionEvent[0]).To(BeTrue())
|
||||
Expect(cong.argsOnCongestionEvent[1]).To(Equal(protocol.ByteCount(2)))
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{1, 1}, {3, 3}}))
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{}))
|
||||
|
||||
// Loose the packet
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 4, Frames: []frames.Frame{}, Length: 4})
|
||||
err = session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 4,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 5, Frames: []frames.Frame{}, Length: 5})
|
||||
err = session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 5,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: 6, Frames: []frames.Frame{}, Length: 6})
|
||||
err = session.handleAckFrame(&frames.AckFrame{
|
||||
LargestObserved: 6,
|
||||
NackRanges: []frames.NackRange{{2, 2}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(cong.argsOnCongestionEvent[2]).To(Equal(congestion.PacketVector{{6, 6}}))
|
||||
Expect(cong.argsOnCongestionEvent[3]).To(Equal(congestion.PacketVector{{2, 2}}))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user