diff --git a/packet_packer.go b/packet_packer.go index 8c7078380..cc170f4ff 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -311,11 +311,17 @@ func (p *packetPacker) packConnectionClose( paddingLen = p.initialPaddingLen(payloads[i].frames, size) } if encLevel == protocol.Encryption1RTT { - shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, keyPhase, payloads[i], paddingLen, sealers[i], false) if err != nil { return nil, err } - packet.shortHdrPacket = shortHdrPacket + packet.shortHdrPacket = &shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: oneRTTPacketNumberLen, + KeyPhase: keyPhase, + } } else { longHdrPacket, err := p.appendLongHeaderPacket(buffer, hdrs[i], payloads[i], paddingLen, encLevel, sealers[i]) if err != nil { @@ -472,11 +478,17 @@ func (p *packetPacker) PackCoalescedPacket(onlyAck bool) (*coalescedPacket, erro } packet.longHdrPackets = append(packet.longHdrPackets, longHdrPacket) } else if oneRTTPayload != nil { - shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, oneRTTPacketNumber, oneRTTPacketNumberLen, kp, oneRTTPayload, 0, oneRTTSealer, false) if err != nil { return nil, err } - packet.shortHdrPacket = shortHdrPacket + packet.shortHdrPacket = &shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: oneRTTPacketNumberLen, + KeyPhase: kp, + } } return packet, nil } @@ -497,11 +509,17 @@ func (p *packetPacker) PackPacket(onlyAck bool, now time.Time) (shortHeaderPacke } kp := sealer.KeyPhase() buffer := getPacketBuffer() - packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, sealer, false) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, sealer, false) if err != nil { return shortHeaderPacket{}, nil, err } - return *packet, buffer, nil + return shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: pnLen, + KeyPhase: kp, + }, buffer, nil } func (p *packetPacker) maybeGetCryptoPacket(maxPacketSize protocol.ByteCount, encLevel protocol.EncryptionLevel, onlyAck, ackAllowed bool) (*wire.ExtendedHeader, *payload) { @@ -697,11 +715,17 @@ func (p *packetPacker) MaybePackProbePacket(encLevel protocol.EncryptionLevel) ( } buffer := getPacketBuffer() packet := &coalescedPacket{buffer: buffer} - shortHdrPacket, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, s, false) + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, 0, s, false) if err != nil { return nil, err } - packet.shortHdrPacket = shortHdrPacket + packet.shortHdrPacket = &shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: pnLen, + KeyPhase: kp, + } return packet, nil } @@ -760,11 +784,18 @@ func (p *packetPacker) PackMTUProbePacket(ping ackhandler.Frame, size protocol.B connID := p.getDestConnID() pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) padding := size - p.shortHeaderPacketLength(connID, pnLen, payload) - protocol.ByteCount(s.Overhead()) - packet, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, s.KeyPhase(), payload, padding, s, true) + kp := s.KeyPhase() + ap, ack, err := p.appendShortHeaderPacket(buffer, connID, pn, pnLen, kp, payload, padding, s, true) if err != nil { return shortHeaderPacket{}, nil, err } - return *packet, buffer, nil + return shortHeaderPacket{ + Packet: ap, + DestConnID: connID, + Ack: ack, + PacketNumberLen: pnLen, + KeyPhase: kp, + }, buffer, nil } func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { @@ -837,7 +868,7 @@ func (p *packetPacker) appendShortHeaderPacket( padding protocol.ByteCount, sealer sealer, isMTUProbePacket bool, -) (*shortHeaderPacket, error) { +) (*ackhandler.Packet, *wire.AckFrame, error) { var paddingLen protocol.ByteCount if payload.length < 4-protocol.ByteCount(pnLen) { paddingLen = 4 - protocol.ByteCount(pnLen) - payload.length @@ -848,21 +879,21 @@ func (p *packetPacker) appendShortHeaderPacket( raw := buffer.Data[startLen:] raw, err := wire.AppendShortHeader(raw, connID, pn, pnLen, kp) if err != nil { - return nil, err + return nil, nil, err } payloadOffset := protocol.ByteCount(len(raw)) if pn != p.pnManager.PopPacketNumber(protocol.Encryption1RTT) { - return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") + return nil, nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match") } raw, err = p.appendPacketPayload(raw, payload, paddingLen) if err != nil { - return nil, err + return nil, nil, err } if !isMTUProbePacket { if size := protocol.ByteCount(len(raw) + sealer.Overhead()); size > p.maxPacketSize { - return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) + return nil, nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize) } } raw = p.encryptPacket(raw, sealer, pn, payloadOffset, protocol.ByteCount(pnLen)) @@ -889,13 +920,7 @@ func (p *packetPacker) appendShortHeaderPacket( ap.SendTime = time.Now() ap.IsPathMTUProbePacket = isMTUProbePacket - return &shortHeaderPacket{ - Packet: ap, - DestConnID: connID, - Ack: payload.ack, - PacketNumberLen: pnLen, - KeyPhase: kp, - }, nil + return ap, payload.ack, nil } func (p *packetPacker) appendPacketPayload(raw []byte, payload *payload, paddingLen protocol.ByteCount) ([]byte, error) {