refactor packet handling functions in the client

This commit is contained in:
Marten Seemann
2018-05-21 12:01:55 +08:00
parent b3fd768a61
commit 97e734e973
2 changed files with 109 additions and 93 deletions

View File

@@ -324,70 +324,84 @@ func (c *client) listen() {
}
break
}
if err := c.handlePacket(addr, data[:n]); err != nil {
c.logger.Errorf("error handling packet: %s", err.Error())
}
c.handleRead(addr, data[:n])
}
}
func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
func (c *client) handleRead(remoteAddr net.Addr, packet []byte) {
rcvTime := time.Now()
r := bytes.NewReader(packet)
hdr, err := wire.ParseHeaderSentByServer(r)
// drop the packet if we can't parse the header
if err != nil {
return fmt.Errorf("error parsing packet from %s: %s", remoteAddr.String(), err.Error())
}
// reject packets with truncated connection id if we didn't request truncation
if hdr.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
c.logger.Errorf("error handling packet: %s", err)
return
}
hdr.Raw = packet[:len(packet)-r.Len()]
packetData := packet[len(packet)-r.Len():]
c.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
}
func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil {
c.logger.Errorf("error handling packet: %s", err)
}
}
func (c *client) handlePacketImpl(p *receivedPacket) error {
// reject packets with truncated connection id if we didn't request truncation
if p.header.OmitConnectionID && !c.config.RequestConnectionIDOmission {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
c.mutex.Lock()
defer c.mutex.Unlock()
// handle Version Negotiation Packets
if hdr.IsVersionNegotiation {
if p.header.IsVersionNegotiation {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
return errors.New("received a delayed Version Negotiation Packet")
}
// version negotiation packets have no payload
if err := c.handleVersionNegotiationPacket(hdr); err != nil {
if err := c.handleVersionNegotiationPacket(p.header); err != nil {
c.session.Close(err)
}
return nil
}
if hdr.IsPublicHeader {
return c.handleGQUICPacket(hdr, r, packetData, remoteAddr, rcvTime)
if p.header.IsPublicHeader {
return c.handleGQUICPacket(p)
}
return c.handleIETFQUICPacket(hdr, packetData, remoteAddr, rcvTime)
return c.handleIETFQUICPacket(p)
}
func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
func (c *client) handleIETFQUICPacket(p *receivedPacket) error {
// reject packets with the wrong connection ID
if !hdr.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
if hdr.IsLongHeader {
switch hdr.Type {
if p.header.IsLongHeader {
switch p.header.Type {
case protocol.PacketTypeRetry:
if c.receivedRetry {
return nil
}
case protocol.PacketTypeHandshake:
default:
return fmt.Errorf("Received unsupported packet type: %s", hdr.Type)
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
}
if protocol.ByteCount(len(packetData)) < hdr.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(packetData), hdr.PayloadLen)
if protocol.ByteCount(len(p.data)) < p.header.PayloadLen {
return fmt.Errorf("packet payload (%d bytes) is smaller than the expected payload length (%d bytes)", len(p.data), p.header.PayloadLen)
}
packetData = packetData[:int(hdr.PayloadLen)]
p.data = p.data[:int(p.header.PayloadLen)]
// TODO(#1312): implement parsing of compound packets
}
@@ -397,29 +411,24 @@ func (c *client) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot
c.versionNegotiated = true
}
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
c.session.handlePacket(p)
return nil
}
func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error {
func (c *client) handleGQUICPacket(p *receivedPacket) error {
// reject packets with the wrong connection ID
if !hdr.OmitConnectionID && !hdr.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", hdr.DestConnectionID, c.srcConnID)
if !p.header.OmitConnectionID && !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
if hdr.ResetFlag {
if p.header.ResetFlag {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != remoteAddr.Network() || cr.String() != remoteAddr.String() || !hdr.DestConnectionID.Equal(c.srcConnID) {
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(r)
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
if err != nil {
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
}
@@ -434,12 +443,7 @@ func (c *client) handleGQUICPacket(hdr *wire.Header, r *bytes.Reader, packetData
c.versionNegotiated = true
}
c.session.handlePacket(&receivedPacket{
remoteAddr: remoteAddr,
header: hdr,
data: packetData,
rcvTime: rcvTime,
})
c.session.handlePacket(p)
return nil
}