forked from quic-go/quic-go
introduce a separate code path for unpacking short header packets
This commit is contained in:
297
connection.go
297
connection.go
@@ -25,7 +25,8 @@ import (
|
||||
)
|
||||
|
||||
type unpacker interface {
|
||||
Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
|
||||
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
|
||||
UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error)
|
||||
}
|
||||
|
||||
type streamGetter interface {
|
||||
@@ -362,7 +363,7 @@ var newConnection = func(
|
||||
s.perspective,
|
||||
s.version,
|
||||
)
|
||||
s.unpacker = newPacketUnpacker(cs, s.version)
|
||||
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
|
||||
return s
|
||||
}
|
||||
@@ -474,7 +475,7 @@ var newClientConnection = func(
|
||||
s.clientHelloWritten = clientHelloWritten
|
||||
s.cryptoStreamHandler = cs
|
||||
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
|
||||
s.unpacker = newPacketUnpacker(cs, s.version)
|
||||
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
|
||||
s.packer = newPacketPacker(
|
||||
srcConnID,
|
||||
s.connIDManager.Get,
|
||||
@@ -858,58 +859,113 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
|
||||
if counter > 0 {
|
||||
p = p.Clone()
|
||||
p.data = data
|
||||
}
|
||||
|
||||
hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen)
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
dropReason := logging.PacketDropHeaderParseError
|
||||
if err == wire.ErrUnsupportedVersion {
|
||||
dropReason = logging.PacketDropUnsupportedVersion
|
||||
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason)
|
||||
s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
|
||||
break
|
||||
}
|
||||
s.logger.Debugf("error parsing packet: %s", err)
|
||||
if destConnID != lastConnID {
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
|
||||
}
|
||||
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if wire.IsLongHeaderPacket(p.data[0]) {
|
||||
hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen)
|
||||
if err != nil {
|
||||
if s.tracer != nil {
|
||||
dropReason := logging.PacketDropHeaderParseError
|
||||
if err == wire.ErrUnsupportedVersion {
|
||||
dropReason = logging.PacketDropUnsupportedVersion
|
||||
}
|
||||
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason)
|
||||
}
|
||||
s.logger.Debugf("error parsing packet: %s", err)
|
||||
break
|
||||
}
|
||||
lastConnID = hdr.DestConnectionID
|
||||
|
||||
if hdr.Version != s.version {
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
|
||||
}
|
||||
s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version)
|
||||
break
|
||||
}
|
||||
|
||||
if counter > 0 {
|
||||
p.buffer.Split()
|
||||
}
|
||||
counter++
|
||||
|
||||
// only log if this actually a coalesced packet
|
||||
if s.logger.Debug() && (counter > 1 || len(rest) > 0) {
|
||||
s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
|
||||
}
|
||||
|
||||
p.data = packetData
|
||||
|
||||
if wasProcessed := s.handleLongHeaderPacket(p, hdr); wasProcessed {
|
||||
processed = true
|
||||
}
|
||||
data = rest
|
||||
} else {
|
||||
if counter > 0 {
|
||||
p.buffer.Split()
|
||||
}
|
||||
processed = s.handleShortHeaderPacket(p)
|
||||
break
|
||||
}
|
||||
|
||||
if hdr.IsLongHeader && hdr.Version != s.version {
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
|
||||
}
|
||||
s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version)
|
||||
break
|
||||
}
|
||||
|
||||
if counter > 0 && hdr.DestConnectionID != lastConnID {
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
|
||||
}
|
||||
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
|
||||
break
|
||||
}
|
||||
lastConnID = hdr.DestConnectionID
|
||||
|
||||
if counter > 0 {
|
||||
p.buffer.Split()
|
||||
}
|
||||
counter++
|
||||
|
||||
// only log if this actually a coalesced packet
|
||||
if s.logger.Debug() && (counter > 1 || len(rest) > 0) {
|
||||
s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
|
||||
}
|
||||
p.data = packetData
|
||||
if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed {
|
||||
processed = true
|
||||
}
|
||||
data = rest
|
||||
}
|
||||
|
||||
p.buffer.MaybeRelease()
|
||||
return processed
|
||||
}
|
||||
|
||||
func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
|
||||
func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
|
||||
var wasQueued bool
|
||||
|
||||
defer func() {
|
||||
// Put back the packet buffer if the packet wasn't queued for later decryption.
|
||||
if !wasQueued {
|
||||
p.buffer.Decrement()
|
||||
}
|
||||
}()
|
||||
|
||||
hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
|
||||
if err != nil {
|
||||
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.logger.Debug() {
|
||||
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID)
|
||||
hdr.Log(s.logger)
|
||||
}
|
||||
|
||||
if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) {
|
||||
s.logger.Debugf("Dropping (potentially) duplicate packet.")
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil {
|
||||
s.closeLocal(err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
|
||||
var wasQueued bool
|
||||
|
||||
defer func() {
|
||||
@@ -925,7 +981,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
|
||||
|
||||
// The server can change the source connection ID with the first Handshake packet.
|
||||
// After this, all packets with a different source connection have to be ignored.
|
||||
if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
|
||||
if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID)
|
||||
}
|
||||
@@ -940,44 +996,9 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
|
||||
return false
|
||||
}
|
||||
|
||||
packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data)
|
||||
packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case handshake.ErrKeysDropped:
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable)
|
||||
}
|
||||
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size())
|
||||
case handshake.ErrKeysNotYetAvailable:
|
||||
// Sealer for this encryption level not yet available.
|
||||
// Try again later.
|
||||
wasQueued = true
|
||||
s.tryQueueingUndecryptablePacket(p, logging.PacketTypeFromHeader(hdr))
|
||||
case wire.ErrInvalidReservedBits:
|
||||
s.closeLocal(&qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
case handshake.ErrDecryptionFailed:
|
||||
// This might be a packet injected by an attacker. Drop it.
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError)
|
||||
}
|
||||
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err)
|
||||
default:
|
||||
var headerErr *headerParseError
|
||||
if errors.As(err, &headerErr) {
|
||||
// This might be a packet injected by an attacker. Drop it.
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err)
|
||||
} else {
|
||||
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
|
||||
// For example, a PROTOCOL_VIOLATION due to key updates.
|
||||
s.closeLocal(err)
|
||||
}
|
||||
}
|
||||
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1001,6 +1022,46 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) {
|
||||
switch err {
|
||||
case handshake.ErrKeysDropped:
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable)
|
||||
}
|
||||
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
|
||||
case handshake.ErrKeysNotYetAvailable:
|
||||
// Sealer for this encryption level not yet available.
|
||||
// Try again later.
|
||||
s.tryQueueingUndecryptablePacket(p, pt)
|
||||
return true
|
||||
case wire.ErrInvalidReservedBits:
|
||||
s.closeLocal(&qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
ErrorMessage: err.Error(),
|
||||
})
|
||||
case handshake.ErrDecryptionFailed:
|
||||
// This might be a packet injected by an attacker. Drop it.
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError)
|
||||
}
|
||||
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
|
||||
default:
|
||||
var headerErr *headerParseError
|
||||
if errors.As(err, &headerErr) {
|
||||
// This might be a packet injected by an attacker. Drop it.
|
||||
if s.tracer != nil {
|
||||
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError)
|
||||
}
|
||||
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
|
||||
} else {
|
||||
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
|
||||
// For example, a PROTOCOL_VIOLATION due to key updates.
|
||||
s.closeLocal(err)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ {
|
||||
if s.perspective == protocol.PerspectiveServer {
|
||||
if s.tracer != nil {
|
||||
@@ -1167,15 +1228,51 @@ func (s *connection) handleUnpackedPacket(
|
||||
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
|
||||
s.keepAlivePingSent = false
|
||||
|
||||
var log func([]logging.Frame)
|
||||
if s.tracer != nil {
|
||||
log = func(frames []logging.Frame) {
|
||||
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames)
|
||||
}
|
||||
}
|
||||
isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
|
||||
}
|
||||
|
||||
func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error {
|
||||
s.lastPacketReceivedTime = rcvTime
|
||||
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
|
||||
s.keepAlivePingSent = false
|
||||
|
||||
var log func([]logging.Frame)
|
||||
if s.tracer != nil {
|
||||
log = func(frames []logging.Frame) {
|
||||
s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames)
|
||||
}
|
||||
}
|
||||
isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
|
||||
}
|
||||
|
||||
func (s *connection) handleFrames(
|
||||
data []byte,
|
||||
destConnID protocol.ConnectionID,
|
||||
encLevel protocol.EncryptionLevel,
|
||||
log func([]logging.Frame),
|
||||
) (isAckEliciting bool, _ error) {
|
||||
// Only used for tracing.
|
||||
// If we're not tracing, this slice will always remain empty.
|
||||
var frames []wire.Frame
|
||||
r := bytes.NewReader(packet.data)
|
||||
var isAckEliciting bool
|
||||
r := bytes.NewReader(data)
|
||||
for {
|
||||
frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel)
|
||||
frame, err := s.frameParser.ParseNext(r, encLevel)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
if frame == nil {
|
||||
break
|
||||
@@ -1185,38 +1282,28 @@ func (s *connection) handleUnpackedPacket(
|
||||
}
|
||||
// Only process frames now if we're not logging.
|
||||
// If we're logging, we need to make sure that the packet_received event is logged first.
|
||||
if s.tracer == nil {
|
||||
if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil {
|
||||
return err
|
||||
if log == nil {
|
||||
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
} else {
|
||||
frames = append(frames, frame)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tracer != nil {
|
||||
if log != nil {
|
||||
fs := make([]logging.Frame, len(frames))
|
||||
for i, frame := range frames {
|
||||
fs[i] = logutils.ConvertFrame(frame)
|
||||
}
|
||||
if packet.hdr.IsLongHeader {
|
||||
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, fs)
|
||||
} else {
|
||||
s.tracer.ReceivedShortHeaderPacket(&wire.ShortHeader{
|
||||
DestConnectionID: packet.hdr.DestConnectionID,
|
||||
PacketNumber: packet.hdr.PacketNumber,
|
||||
PacketNumberLen: packet.hdr.PacketNumberLen,
|
||||
KeyPhase: packet.hdr.KeyPhase,
|
||||
}, packetSize, fs)
|
||||
}
|
||||
log(fs)
|
||||
for _, frame := range frames {
|
||||
if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil {
|
||||
return err
|
||||
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error {
|
||||
|
||||
Reference in New Issue
Block a user