Files
quic-go/packet_packer.go

261 lines
8.2 KiB
Go

package quic
import (
"bytes"
"errors"
"sync/atomic"
"github.com/lucas-clemente/quic-go/ackhandler"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
)
type packedPacket struct {
number protocol.PacketNumber
entropyBit bool
raw []byte
frames []frames.Frame
}
type packetPacker struct {
connectionID protocol.ConnectionID
version protocol.VersionNumber
aead crypto.AEAD
sentPacketHandler ackhandler.SentPacketHandler
connectionParametersManager *handshake.ConnectionParametersManager
streamFrameQueue streamFrameQueue
controlFrames []frames.Frame
blockedManager *blockedManager
lastPacketNumber protocol.PacketNumber
}
func newPacketPacker(connectionID protocol.ConnectionID, aead crypto.AEAD, sentPacketHandler ackhandler.SentPacketHandler, connectionParametersHandler *handshake.ConnectionParametersManager, blockedManager *blockedManager, version protocol.VersionNumber) *packetPacker {
return &packetPacker{
aead: aead,
connectionID: connectionID,
connectionParametersManager: connectionParametersHandler,
version: version,
sentPacketHandler: sentPacketHandler,
blockedManager: blockedManager,
}
}
func (p *packetPacker) AddStreamFrame(f frames.StreamFrame) {
p.streamFrameQueue.Push(&f, false)
}
func (p *packetPacker) AddHighPrioStreamFrame(f frames.StreamFrame) {
p.streamFrameQueue.Push(&f, true)
}
func (p *packetPacker) AddBlocked(streamID protocol.StreamID, byteOffset protocol.ByteCount) {
// TODO: send out connection-level BlockedFrames at the right time
// see https://github.com/lucas-clemente/quic-go/issues/113
// TODO: remove this function completely once #113 is resolved
if streamID == 0 {
p.controlFrames = append(p.controlFrames, &frames.BlockedFrame{StreamID: 0})
}
p.blockedManager.AddBlockedStream(streamID, byteOffset)
}
func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame) (*packedPacket, error) {
return p.packPacket(nil, []frames.Frame{frame}, true)
}
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame) (*packedPacket, error) {
return p.packPacket(stopWaitingFrame, controlFrames, false)
}
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, onlySendOneControlFrame bool) (*packedPacket, error) {
// don't send out packets that only contain a StopWaitingFrame
if len(p.controlFrames) == 0 && len(controlFrames) == 0 && p.streamFrameQueue.Len() == 0 {
return nil, nil
}
if len(controlFrames) > 0 {
p.controlFrames = append(p.controlFrames, controlFrames...)
}
currentPacketNumber := protocol.PacketNumber(atomic.AddUint64(
(*uint64)(&p.lastPacketNumber),
1,
))
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, p.sentPacketHandler.GetLargestObserved())
responsePublicHeader := &publicHeader{
ConnectionID: p.connectionID,
PacketNumber: currentPacketNumber,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParametersManager.TruncateConnectionID(),
DiversificationNonce: p.aead.DiversificationNonce(),
}
publicHeaderLength, err := responsePublicHeader.GetLength()
if err != nil {
return nil, err
}
if stopWaitingFrame != nil {
stopWaitingFrame.PacketNumber = currentPacketNumber
stopWaitingFrame.PacketNumberLen = packetNumberLen
}
var payloadFrames []frames.Frame
if onlySendOneControlFrame {
payloadFrames = []frames.Frame{controlFrames[0]}
} else {
payloadFrames, err = p.composeNextPacket(stopWaitingFrame, publicHeaderLength)
if err != nil {
return nil, err
}
}
payload, err := p.getPayload(payloadFrames, currentPacketNumber)
if err != nil {
return nil, err
}
entropyBit, err := utils.RandomBit()
if err != nil {
return nil, err
}
if entropyBit {
payload[0] = 1
}
var raw bytes.Buffer
if err := responsePublicHeader.WritePublicHeader(&raw, p.version); err != nil {
return nil, err
}
ciphertext := p.aead.Seal(currentPacketNumber, raw.Bytes(), payload)
raw.Write(ciphertext)
if protocol.ByteCount(raw.Len()) > protocol.MaxPacketSize {
return nil, errors.New("PacketPacker BUG: packet too large")
}
return &packedPacket{
number: currentPacketNumber,
entropyBit: entropyBit,
raw: raw.Bytes(),
frames: payloadFrames,
}, nil
}
func (p *packetPacker) getPayload(frames []frames.Frame, currentPacketNumber protocol.PacketNumber) ([]byte, error) {
var payload bytes.Buffer
payload.WriteByte(0) // The entropy bit is set in sendPayload
for _, frame := range frames {
frame.Write(&payload, p.version)
}
return payload.Bytes(), nil
}
func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFrame, publicHeaderLength protocol.ByteCount) ([]frames.Frame, error) {
var payloadLength protocol.ByteCount
var payloadFrames []frames.Frame
maxFrameSize := protocol.MaxFrameAndPublicHeaderSize - publicHeaderLength
if stopWaitingFrame != nil {
payloadFrames = append(payloadFrames, stopWaitingFrame)
minLength, err := stopWaitingFrame.MinLength()
if err != nil {
return nil, err
}
payloadLength += minLength
}
for len(p.controlFrames) > 0 {
frame := p.controlFrames[0]
minLength, _ := frame.MinLength() // controlFrames does not contain any StopWaitingFrames. So it will *never* return an error
if payloadLength+minLength > maxFrameSize {
break
}
payloadFrames = append(payloadFrames, frame)
payloadLength += minLength
p.controlFrames = p.controlFrames[1:]
}
if payloadLength > maxFrameSize {
return nil, errors.New("PacketPacker BUG: packet payload too large")
}
hasStreamFrames := false
// temporarily increase the maxFrameSize by 2 bytes
// this leads to a properly sized packet in all cases, since we do all the packet length calculations with StreamFrames that have the DataLen set
// however, for the last StreamFrame in the packet, we can omit the DataLen, thus saving 2 bytes and yielding a packet of exactly the correct size
maxFrameSize += 2
for p.streamFrameQueue.Len() > 0 {
frame := p.streamFrameQueue.Front()
frame.DataLenPresent = true // set the dataLen by default. Remove them later if applicable
if payloadLength > maxFrameSize {
return nil, errors.New("PacketPacker BUG: packet payload too large")
}
// Does the frame fit into the remaining space?
frameMinLength, _ := frame.MinLength() // StreamFrame.MinLength *never* returns an error
if payloadLength+frameMinLength > maxFrameSize {
break
}
// Split stream frames if necessary
previousFrame := frame.MaybeSplitOffFrame(maxFrameSize - payloadLength)
if previousFrame != nil {
// Don't pop the queue, leave the modified frame in
frame = previousFrame
payloadLength += protocol.ByteCount(len(previousFrame.Data)) - 1
} else {
p.streamFrameQueue.Pop()
payloadLength += protocol.ByteCount(len(frame.Data)) - 1
}
blockedFrame := p.blockedManager.GetBlockedFrame(frame.StreamID, frame.Offset+protocol.ByteCount(len(frame.Data)))
if blockedFrame != nil {
blockedLength, _ := blockedFrame.MinLength() // BlockedFrame.MinLength *never* returns an error
if payloadLength+frameMinLength+blockedLength <= maxFrameSize {
payloadFrames = append(payloadFrames, blockedFrame)
payloadLength += blockedLength
} else {
p.controlFrames = append(p.controlFrames, blockedFrame)
}
}
payloadLength += frameMinLength
payloadFrames = append(payloadFrames, frame)
hasStreamFrames = true
}
// remove the dataLen for the last StreamFrame in the packet
if hasStreamFrames {
lastStreamFrame, ok := payloadFrames[len(payloadFrames)-1].(*frames.StreamFrame)
if !ok {
return nil, errors.New("PacketPacker BUG: StreamFrame type assertion failed")
}
lastStreamFrame.DataLenPresent = false
// payloadLength -= 2
}
return payloadFrames, nil
}
// Empty returns true if no frames are queued
func (p *packetPacker) Empty() bool {
return p.streamFrameQueue.Front() == nil
}
func (p *packetPacker) StreamFrameQueueByteLen() protocol.ByteCount {
return p.streamFrameQueue.ByteLen()
}