diff --git a/packet_packer.go b/packet_packer.go index c96307cbc..9d3073d73 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -215,7 +215,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac hdr = p.getShortHeader(s.KeyPhase()) } - return p.writeAndSealPacket(hdr, payload, encLevel, sealer) + return p.writeSinglePacket(hdr, payload, encLevel, sealer) } func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacket, error) { @@ -251,7 +251,7 @@ func (p *packetPacker) MaybePackAckPacket(handshakeConfirmed bool) (*packedPacke if err != nil { return nil, err } - return p.writeAndSealPacket(hdr, payload, encLevel, sealer) + return p.writeSinglePacket(hdr, payload, encLevel, sealer) } // PackPacket packs a new packet. @@ -357,7 +357,7 @@ func (p *packetPacker) packCryptoPacket( payload.frames = []ackhandler.Frame{{Frame: cf}} payload.length += cf.Length(p.version) } - return p.writeAndSealPacket(hdr, payload, encLevel, sealer) + return p.writeSinglePacket(hdr, payload, encLevel, sealer) } func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { @@ -403,7 +403,7 @@ func (p *packetPacker) maybePackAppDataPacket() (*packedPacket, error) { p.numNonAckElicitingAcks = 0 } - return p.writeAndSealPacket(header, payload, encLevel, sealer) + return p.writeSinglePacket(header, payload, encLevel, sealer) } func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) payload { @@ -529,15 +529,37 @@ func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.Ex return hdr } -func (p *packetPacker) writeAndSealPacket( +// writeSinglePacket packs a single packet. +func (p *packetPacker) writeSinglePacket( header *wire.ExtendedHeader, payload payload, encLevel protocol.EncryptionLevel, sealer sealer, ) (*packedPacket, error) { + packetBuffer := getPacketBuffer() + + n, err := p.appendPacket(packetBuffer.Slice[:0], header, payload, encLevel, sealer) + if err != nil { + return nil, err + } + return &packedPacket{ + header: header, + raw: packetBuffer.Slice[:n], + ack: payload.ack, + frames: payload.frames, + buffer: packetBuffer, + }, nil +} + +func (p *packetPacker) appendPacket( + raw []byte, + header *wire.ExtendedHeader, + payload payload, + encLevel protocol.EncryptionLevel, + sealer sealer, +) (int, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) - if encLevel != protocol.Encryption1RTT { if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial { headerLen := header.GetLength(p.version) @@ -549,27 +571,17 @@ func (p *packetPacker) writeAndSealPacket( } else if payload.length < 4-pnLen { paddingLen = 4 - pnLen - payload.length } - return p.writeAndSealPacketWithPadding(header, payload, paddingLen, encLevel, sealer) -} - -func (p *packetPacker) writeAndSealPacketWithPadding( - header *wire.ExtendedHeader, - payload payload, - paddingLen protocol.ByteCount, - encLevel protocol.EncryptionLevel, - sealer sealer, -) (*packedPacket, error) { - packetBuffer := getPacketBuffer() - buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) + hdrOffset := len(raw) + buffer := bytes.NewBuffer(raw) if err := header.Write(buffer, p.version); err != nil { - return nil, err + return 0, err } payloadOffset := buffer.Len() if payload.ack != nil { if err := payload.ack.Write(buffer, p.version); err != nil { - return nil, err + return 0, err } } if paddingLen > 0 { @@ -577,40 +589,29 @@ func (p *packetPacker) writeAndSealPacketWithPadding( } for _, frame := range payload.frames { if err := frame.Write(buffer, p.version); err != nil { - return nil, err + return 0, err } } if payloadSize := protocol.ByteCount(buffer.Len()-payloadOffset) - paddingLen; payloadSize != payload.length { - fmt.Printf("%#v\n", payload) - return nil, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) + return 0, fmt.Errorf("PacketPacker BUG: payload size inconsistent (expected %d, got %d bytes)", payload.length, payloadSize) } if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize { - return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + return 0, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } - raw := buffer.Bytes() - _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset]) + raw = raw[:buffer.Len()] + _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[hdrOffset:payloadOffset]) raw = raw[0 : buffer.Len()+sealer.Overhead()] pnOffset := payloadOffset - int(header.PacketNumberLen) - sealer.EncryptHeader( - raw[pnOffset+4:pnOffset+4+16], - &raw[0], - raw[pnOffset:payloadOffset], - ) + sealer.EncryptHeader(raw[pnOffset+4:pnOffset+4+16], &raw[0], raw[pnOffset:payloadOffset]) num := p.pnManager.PopPacketNumber(encLevel) if num != header.PacketNumber { - return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + return 0, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } - return &packedPacket{ - header: header, - raw: raw, - ack: payload.ack, - frames: payload.frames, - buffer: packetBuffer, - }, nil + return len(raw) - hdrOffset, nil } func (p *packetPacker) SetToken(token []byte) {