diff --git a/packet_packer.go b/packet_packer.go index af5c2b52c..bda585d96 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -33,6 +33,23 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { } } +func getMaxPacketSize(addr net.Addr) protocol.ByteCount { + maxSize := protocol.ByteCount(protocol.MinInitialPacketSize) + // If this is not a UDP address, we don't know anything about the MTU. + // Use the minimum size of an Initial packet as the max packet size. + if udpAddr, ok := addr.(*net.UDPAddr); ok { + // If ip is not an IPv4 address, To4 returns nil. + // Note that there might be some corner cases, where this is not correct. + // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. + if udpAddr.IP.To4() == nil { + maxSize = protocol.MaxPacketSizeIPv6 + } else { + maxSize = protocol.MaxPacketSizeIPv4 + } + } + return maxSize +} + type sealingManager interface { GetSealer() (protocol.EncryptionLevel, handshake.Sealer) GetSealerForCryptoStream() (protocol.EncryptionLevel, handshake.Sealer) @@ -84,19 +101,6 @@ func newPacketPacker( perspective protocol.Perspective, version protocol.VersionNumber, ) *packetPacker { - maxPacketSize := protocol.ByteCount(protocol.MinInitialPacketSize) - // If this is not a UDP address, we don't know anything about the MTU. - // Use the minimum size of an Initial packet as the max packet size. - if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { - // If ip is not an IPv4 address, To4 returns nil. - // Note that there might be some corner cases, where this is not correct. - // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6. - if udpAddr.IP.To4() == nil { - maxPacketSize = protocol.MaxPacketSizeIPv6 - } else { - maxPacketSize = protocol.MaxPacketSizeIPv4 - } - } return &packetPacker{ cryptoSetup: cryptoSetup, divNonce: divNonce, @@ -108,7 +112,7 @@ func newPacketPacker( streams: streamFramer, getPacketNumberLen: getPacketNumberLen, packetNumberGenerator: newPacketNumberGenerator(initialPacketNumber, protocol.SkipPacketAveragePeriodLength), - maxPacketSize: maxPacketSize, + maxPacketSize: getMaxPacketSize(remoteAddr), } } diff --git a/packet_packer_test.go b/packet_packer_test.go index 985f99d96..02cba9edd 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -108,25 +108,19 @@ var _ = Describe("Packet packer", func() { }) Context("determining the maximum packet size", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - It("uses the minimum initial size, if it can't determine if the remote address is IPv4 or IPv6", func() { - remoteAddr := &net.TCPAddr{} - packer = newPacketPacker(connID, connID, 1, nil, remoteAddr, nil, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MinInitialPacketSize)) + Expect(getMaxPacketSize(&net.TCPAddr{})).To(BeEquivalentTo(protocol.MinInitialPacketSize)) }) It("uses the maximum IPv4 packet size, if the remote address is IPv4", func() { - remoteAddr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} - packer = newPacketPacker(connID, connID, 1, nil, remoteAddr, nil, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv4)) + addr := &net.UDPAddr{IP: net.IPv4(11, 12, 13, 14), Port: 1337} + Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.MaxPacketSizeIPv4)) }) It("uses the maximum IPv6 packet size, if the remote address is IPv6", func() { ip := net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334") - remoteAddr := &net.UDPAddr{IP: ip, Port: 1337} - packer = newPacketPacker(connID, connID, 1, nil, remoteAddr, nil, nil, nil, nil, protocol.PerspectiveServer, protocol.VersionWhatever) - Expect(packer.maxPacketSize).To(BeEquivalentTo(protocol.MaxPacketSizeIPv6)) + addr := &net.UDPAddr{IP: ip, Port: 1337} + Expect(getMaxPacketSize(addr)).To(BeEquivalentTo(protocol.MaxPacketSizeIPv6)) }) })