introduce a separate code path for unpacking short header packets

This commit is contained in:
Marten Seemann
2022-08-27 22:30:09 +03:00
parent ed15a94703
commit 4f3d3b36ac
7 changed files with 390 additions and 245 deletions

View File

@@ -25,7 +25,8 @@ import (
)
type unpacker interface {
Unpack(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte) (*unpackedPacket, error)
UnpackShortHeader(rcvTime time.Time, data []byte) (*wire.ShortHeader, []byte, error)
}
type streamGetter interface {
@@ -362,7 +363,7 @@ var newConnection = func(
s.perspective,
s.version,
)
s.unpacker = newPacketUnpacker(cs, s.version)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, s.oneRTTStream)
return s
}
@@ -474,7 +475,7 @@ var newClientConnection = func(
s.clientHelloWritten = clientHelloWritten
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream())
s.unpacker = newPacketUnpacker(cs, s.version)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen, s.version)
s.packer = newPacketPacker(
srcConnID,
s.connIDManager.Get,
@@ -858,58 +859,113 @@ func (s *connection) handlePacketImpl(rp *receivedPacket) bool {
if counter > 0 {
p = p.Clone()
p.data = data
}
hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen)
if err != nil {
if s.tracer != nil {
dropReason := logging.PacketDropHeaderParseError
if err == wire.ErrUnsupportedVersion {
dropReason = logging.PacketDropUnsupportedVersion
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
}
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason)
s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
break
}
s.logger.Debugf("error parsing packet: %s", err)
if destConnID != lastConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
}
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID)
break
}
}
if wire.IsLongHeaderPacket(p.data[0]) {
hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnIDLen)
if err != nil {
if s.tracer != nil {
dropReason := logging.PacketDropHeaderParseError
if err == wire.ErrUnsupportedVersion {
dropReason = logging.PacketDropUnsupportedVersion
}
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.ByteCount(len(data)), dropReason)
}
s.logger.Debugf("error parsing packet: %s", err)
break
}
lastConnID = hdr.DestConnectionID
if hdr.Version != s.version {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
}
s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version)
break
}
if counter > 0 {
p.buffer.Split()
}
counter++
// only log if this actually a coalesced packet
if s.logger.Debug() && (counter > 1 || len(rest) > 0) {
s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
}
p.data = packetData
if wasProcessed := s.handleLongHeaderPacket(p, hdr); wasProcessed {
processed = true
}
data = rest
} else {
if counter > 0 {
p.buffer.Split()
}
processed = s.handleShortHeaderPacket(p)
break
}
if hdr.IsLongHeader && hdr.Version != s.version {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
}
s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version)
break
}
if counter > 0 && hdr.DestConnectionID != lastConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
}
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID)
break
}
lastConnID = hdr.DestConnectionID
if counter > 0 {
p.buffer.Split()
}
counter++
// only log if this actually a coalesced packet
if s.logger.Debug() && (counter > 1 || len(rest) > 0) {
s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
}
p.data = packetData
if wasProcessed := s.handleSinglePacket(p, hdr); wasProcessed {
processed = true
}
data = rest
}
p.buffer.MaybeRelease()
return processed
}
func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
func (s *connection) handleShortHeaderPacket(p *receivedPacket) bool {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
p.buffer.Decrement()
}
}()
hdr, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
if err != nil {
wasQueued = s.handleUnpackError(err, p, logging.PacketType1RTT)
return false
}
if s.logger.Debug() {
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", hdr.PacketNumber, p.Size(), hdr.DestConnectionID)
hdr.Log(s.logger)
}
if s.receivedPacketHandler.IsPotentiallyDuplicate(hdr.PacketNumber, protocol.Encryption1RTT) {
s.logger.Debugf("Dropping (potentially) duplicate packet.")
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketType1RTT, p.Size(), logging.PacketDropDuplicate)
}
return false
}
if err := s.handleUnpackedShortHeaderPacket(hdr, data, p.ecn, p.rcvTime, p.Size()); err != nil {
s.closeLocal(err)
return false
}
return true
}
func (s *connection) handleLongHeaderPacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ {
var wasQueued bool
defer func() {
@@ -925,7 +981,7 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
// The server can change the source connection ID with the first Handshake packet.
// After this, all packets with a different source connection have to be ignored.
if s.receivedFirstPacket && hdr.IsLongHeader && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeInitial, p.Size(), logging.PacketDropUnknownConnectionID)
}
@@ -940,44 +996,9 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
return false
}
packet, err := s.unpacker.Unpack(hdr, p.rcvTime, p.data)
packet, err := s.unpacker.UnpackLongHeader(hdr, p.rcvTime, p.data)
if err != nil {
switch err {
case handshake.ErrKeysDropped:
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropKeyUnavailable)
}
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", hdr.PacketType(), p.Size())
case handshake.ErrKeysNotYetAvailable:
// Sealer for this encryption level not yet available.
// Try again later.
wasQueued = true
s.tryQueueingUndecryptablePacket(p, logging.PacketTypeFromHeader(hdr))
case wire.ErrInvalidReservedBits:
s.closeLocal(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: err.Error(),
})
case handshake.ErrDecryptionFailed:
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropPayloadDecryptError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", hdr.PacketType(), p.Size(), err)
default:
var headerErr *headerParseError
if errors.As(err, &headerErr) {
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", hdr.PacketType(), p.Size(), err)
} else {
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
// For example, a PROTOCOL_VIOLATION due to key updates.
s.closeLocal(err)
}
}
wasQueued = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false
}
@@ -1001,6 +1022,46 @@ func (s *connection) handleSinglePacket(p *receivedPacket, hdr *wire.Header) boo
return true
}
func (s *connection) handleUnpackError(err error, p *receivedPacket, pt logging.PacketType) (wasQueued bool) {
switch err {
case handshake.ErrKeysDropped:
if s.tracer != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropKeyUnavailable)
}
s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
case handshake.ErrKeysNotYetAvailable:
// Sealer for this encryption level not yet available.
// Try again later.
s.tryQueueingUndecryptablePacket(p, pt)
return true
case wire.ErrInvalidReservedBits:
s.closeLocal(&qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: err.Error(),
})
case handshake.ErrDecryptionFailed:
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropPayloadDecryptError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
default:
var headerErr *headerParseError
if errors.As(err, &headerErr) {
// This might be a packet injected by an attacker. Drop it.
if s.tracer != nil {
s.tracer.DroppedPacket(pt, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
} else {
// This is an error returned by the AEAD (other than ErrDecryptionFailed).
// For example, a PROTOCOL_VIOLATION due to key updates.
s.closeLocal(err)
}
}
return false
}
func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was this a valid Retry */ {
if s.perspective == protocol.PerspectiveServer {
if s.tracer != nil {
@@ -1167,15 +1228,51 @@ func (s *connection) handleUnpackedPacket(
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false
var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, frames)
}
}
isAckEliciting, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log)
if err != nil {
return err
}
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
}
func (s *connection) handleUnpackedShortHeaderPacket(hdr *wire.ShortHeader, data []byte, ecn protocol.ECN, rcvTime time.Time, packetSize protocol.ByteCount) error {
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false
var log func([]logging.Frame)
if s.tracer != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedShortHeaderPacket(hdr, packetSize, frames)
}
}
isAckEliciting, err := s.handleFrames(data, hdr.DestConnectionID, protocol.Encryption1RTT, log)
if err != nil {
return err
}
return s.receivedPacketHandler.ReceivedPacket(hdr.PacketNumber, ecn, protocol.Encryption1RTT, rcvTime, isAckEliciting)
}
func (s *connection) handleFrames(
data []byte,
destConnID protocol.ConnectionID,
encLevel protocol.EncryptionLevel,
log func([]logging.Frame),
) (isAckEliciting bool, _ error) {
// Only used for tracing.
// If we're not tracing, this slice will always remain empty.
var frames []wire.Frame
r := bytes.NewReader(packet.data)
var isAckEliciting bool
r := bytes.NewReader(data)
for {
frame, err := s.frameParser.ParseNext(r, packet.encryptionLevel)
frame, err := s.frameParser.ParseNext(r, encLevel)
if err != nil {
return err
return false, err
}
if frame == nil {
break
@@ -1185,38 +1282,28 @@ func (s *connection) handleUnpackedPacket(
}
// Only process frames now if we're not logging.
// If we're logging, we need to make sure that the packet_received event is logged first.
if s.tracer == nil {
if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil {
return err
if log == nil {
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
return false, err
}
} else {
frames = append(frames, frame)
}
}
if s.tracer != nil {
if log != nil {
fs := make([]logging.Frame, len(frames))
for i, frame := range frames {
fs[i] = logutils.ConvertFrame(frame)
}
if packet.hdr.IsLongHeader {
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, fs)
} else {
s.tracer.ReceivedShortHeaderPacket(&wire.ShortHeader{
DestConnectionID: packet.hdr.DestConnectionID,
PacketNumber: packet.hdr.PacketNumber,
PacketNumberLen: packet.hdr.PacketNumberLen,
KeyPhase: packet.hdr.KeyPhase,
}, packetSize, fs)
}
log(fs)
for _, frame := range frames {
if err := s.handleFrame(frame, packet.encryptionLevel, packet.hdr.DestConnectionID); err != nil {
return err
if err := s.handleFrame(frame, encLevel, destConnID); err != nil {
return false, err
}
}
}
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
return
}
func (s *connection) handleFrame(f wire.Frame, encLevel protocol.EncryptionLevel, destConnID protocol.ConnectionID) error {