forked from quic-go/quic-go
refactor packet handling functions in the client
This commit is contained in:
116
client_test.go
116
client_test.go
@@ -388,16 +388,13 @@ var _ = Describe("Client", func() {
|
||||
sess := NewMockQuicSession(mockCtrl)
|
||||
sess.EXPECT().handlePacket(gomock.Any())
|
||||
cl.session = sess
|
||||
ph := wire.Header{
|
||||
ph := &wire.Header{
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen2,
|
||||
DestConnectionID: connID,
|
||||
SrcConnectionID: connID,
|
||||
}
|
||||
b := &bytes.Buffer{}
|
||||
err := ph.Write(b, protocol.PerspectiveServer, protocol.VersionWhatever)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cl.handlePacket(nil, b.Bytes())
|
||||
err := cl.handlePacketImpl(&receivedPacket{header: ph})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cl.versionNegotiated).To(BeTrue())
|
||||
})
|
||||
@@ -439,8 +436,7 @@ var _ = Describe("Client", func() {
|
||||
close(dialed)
|
||||
}()
|
||||
Eventually(sessionChan).Should(HaveLen(1))
|
||||
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2}))
|
||||
Eventually(sessionChan).Should(BeEmpty())
|
||||
})
|
||||
|
||||
@@ -474,7 +470,7 @@ var _ = Describe("Client", func() {
|
||||
return <-sessionChan, nil
|
||||
}
|
||||
|
||||
cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2}}
|
||||
cl.config = &Config{Versions: []protocol.VersionNumber{version1, version2, version3}}
|
||||
dialed := make(chan struct{})
|
||||
go func() {
|
||||
defer GinkgoRecover()
|
||||
@@ -483,12 +479,12 @@ var _ = Describe("Client", func() {
|
||||
close(dialed)
|
||||
}()
|
||||
Eventually(sessionChan).Should(HaveLen(1))
|
||||
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version2}))
|
||||
Eventually(sessionChan).Should(BeEmpty())
|
||||
err = cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version3}))
|
||||
Expect(err).To(MatchError("received a delayed Version Negotiation Packet"))
|
||||
Expect(cl.version).To(Equal(version2))
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{version3}))
|
||||
Eventually(dialed).Should(BeClosed())
|
||||
Expect(cl.version).To(Equal(version2))
|
||||
})
|
||||
|
||||
It("errors if no matching version is found", func() {
|
||||
@@ -496,8 +492,7 @@ var _ = Describe("Client", func() {
|
||||
sess.EXPECT().Close(gomock.Any())
|
||||
cl.session = sess
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{1}))
|
||||
})
|
||||
|
||||
It("errors if the version is supported by quic-go, but disabled by the quic.Config", func() {
|
||||
@@ -507,8 +502,7 @@ var _ = Describe("Client", func() {
|
||||
v := protocol.VersionNumber(1234)
|
||||
Expect(v).ToNot(Equal(cl.version))
|
||||
cl.config = &Config{Versions: protocol.SupportedVersions}
|
||||
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{v}))
|
||||
})
|
||||
|
||||
It("changes to the version preferred by the quic.Config", func() {
|
||||
@@ -517,15 +511,13 @@ var _ = Describe("Client", func() {
|
||||
cl.session = sess
|
||||
config := &Config{Versions: []protocol.VersionNumber{1234, 4321}}
|
||||
cl.config = config
|
||||
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{4321, 1234}))
|
||||
Expect(cl.version).To(Equal(protocol.VersionNumber(1234)))
|
||||
})
|
||||
|
||||
It("drops version negotiation packets that contain the offered version", func() {
|
||||
ver := cl.version
|
||||
err := cl.handlePacket(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver}))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handleRead(nil, wire.ComposeGQUICVersionNegotiation(connID, []protocol.VersionNumber{ver}))
|
||||
Expect(cl.version).To(Equal(ver))
|
||||
})
|
||||
})
|
||||
@@ -533,14 +525,11 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("ignores packets with an invalid public header", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
err := cl.handlePacket(addr, []byte("invalid packet"))
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("error parsing packet from"))
|
||||
cl.handleRead(addr, []byte("invalid packet"))
|
||||
})
|
||||
|
||||
It("errors on packets that are smaller than the Payload Length in the packet header", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
@@ -550,8 +539,12 @@ var _ = Describe("Client", func() {
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...))
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: make([]byte, 456),
|
||||
})
|
||||
Expect(err).To(MatchError("received a packet with an unexpected connection ID (0x0102030405060708, expected 0x0000000000001337)"))
|
||||
})
|
||||
|
||||
It("cuts packets at the payload length", func() {
|
||||
@@ -560,7 +553,6 @@ var _ = Describe("Client", func() {
|
||||
Expect(packet.data).To(HaveLen(123))
|
||||
})
|
||||
cl.session = sess
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeHandshake,
|
||||
@@ -570,13 +562,15 @@ var _ = Describe("Client", func() {
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed())
|
||||
err := cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...))
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: make([]byte, 456),
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("ignores packets with the wrong Long Header Type", func() {
|
||||
b := &bytes.Buffer{}
|
||||
hdr := &wire.Header{
|
||||
IsLongHeader: true,
|
||||
Type: protocol.PacketTypeInitial,
|
||||
@@ -586,43 +580,48 @@ var _ = Describe("Client", func() {
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}
|
||||
Expect(hdr.Write(b, protocol.PerspectiveServer, versionIETFFrames)).To(Succeed())
|
||||
err := cl.handlePacket(addr, append(b.Bytes(), make([]byte, 456)...))
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: make([]byte, 456),
|
||||
})
|
||||
Expect(err).To(MatchError("Received unsupported packet type: Initial"))
|
||||
})
|
||||
|
||||
It("ignores packets without connection id, if it didn't request connection id trunctation", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
cl.config = &Config{RequestConnectionIDOmission: false}
|
||||
buf := &bytes.Buffer{}
|
||||
err := (&wire.Header{
|
||||
hdr := &wire.Header{
|
||||
OmitConnectionID: true,
|
||||
SrcConnectionID: connID,
|
||||
DestConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
}).Write(buf, protocol.PerspectiveServer, versionGQUICFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cl.handlePacket(addr, buf.Bytes())
|
||||
PacketNumberLen: 1,
|
||||
}
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
})
|
||||
Expect(err).To(MatchError("received packet with truncated connection ID, but didn't request truncation"))
|
||||
})
|
||||
|
||||
It("ignores packets with the wrong destination connection ID", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any handlePacket calls
|
||||
buf := &bytes.Buffer{}
|
||||
cl.version = versionIETFFrames
|
||||
cl.config = &Config{RequestConnectionIDOmission: false}
|
||||
connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
|
||||
Expect(connID).ToNot(Equal(connID2))
|
||||
err := (&wire.Header{
|
||||
hdr := &wire.Header{
|
||||
DestConnectionID: connID2,
|
||||
SrcConnectionID: connID,
|
||||
PacketNumber: 1,
|
||||
PacketNumberLen: protocol.PacketNumberLen1,
|
||||
Version: versionIETFFrames,
|
||||
}).Write(buf, protocol.PerspectiveServer, versionIETFFrames)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cl.handlePacket(addr, buf.Bytes())
|
||||
}
|
||||
err := cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
})
|
||||
Expect(err).To(MatchError(fmt.Sprintf("received a packet with an unexpected connection ID (0x0807060504030201, expected %s)", connID)))
|
||||
})
|
||||
|
||||
@@ -695,13 +694,13 @@ var _ = Describe("Client", func() {
|
||||
|
||||
It("only accepts one Retry packet", func() {
|
||||
config := &Config{Versions: []protocol.VersionNumber{protocol.VersionTLS}}
|
||||
sess1 := NewMockPacketHandler(mockCtrl)
|
||||
sess1 := NewMockQuicSession(mockCtrl)
|
||||
sess1.EXPECT().run().Return(handshake.ErrCloseSessionForRetry)
|
||||
// don't EXPECT any call to handlePacket()
|
||||
sess2 := NewMockPacketHandler(mockCtrl)
|
||||
sess2 := NewMockQuicSession(mockCtrl)
|
||||
run := make(chan struct{})
|
||||
sess2.EXPECT().run().Do(func() { <-run })
|
||||
sessions := make(chan *MockPacketHandler, 2)
|
||||
sessions := make(chan *MockQuicSession, 2)
|
||||
sessions <- sess1
|
||||
sessions <- sess2
|
||||
newTLSClientSession = func(
|
||||
@@ -716,7 +715,7 @@ var _ = Describe("Client", func() {
|
||||
paramsChan <-chan handshake.TransportParameters,
|
||||
_ protocol.PacketNumber,
|
||||
_ utils.Logger,
|
||||
) (packetHandler, error) {
|
||||
) (quicSession, error) {
|
||||
return <-sessions, nil
|
||||
}
|
||||
|
||||
@@ -795,22 +794,35 @@ var _ = Describe("Client", func() {
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.PublicReset))
|
||||
})
|
||||
cl.session = sess
|
||||
err := cl.handlePacket(addr, wire.WritePublicReset(cl.destConnID, 1, 0))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cl.handleRead(addr, wire.WritePublicReset(cl.destConnID, 1, 0))
|
||||
})
|
||||
|
||||
It("ignores Public Resets from the wrong remote address", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
|
||||
spoofedAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 5678}
|
||||
err := cl.handlePacket(spoofedAddr, wire.WritePublicReset(cl.destConnID, 1, 0))
|
||||
pr := wire.WritePublicReset(cl.destConnID, 1, 0)
|
||||
r := bytes.NewReader(pr)
|
||||
hdr, err := wire.ParseHeaderSentByServer(r)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: spoofedAddr,
|
||||
header: hdr,
|
||||
data: pr[len(pr)-r.Len():],
|
||||
})
|
||||
Expect(err).To(MatchError("Received a spoofed Public Reset"))
|
||||
})
|
||||
|
||||
It("ignores unparseable Public Resets", func() {
|
||||
cl.session = NewMockQuicSession(mockCtrl) // don't EXPECT any calls
|
||||
pr := wire.WritePublicReset(cl.destConnID, 1, 0)
|
||||
err := cl.handlePacket(addr, pr[:len(pr)-5])
|
||||
Expect(err).To(HaveOccurred())
|
||||
r := bytes.NewReader(pr)
|
||||
hdr, err := wire.ParseHeaderSentByServer(r)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
err = cl.handlePacketImpl(&receivedPacket{
|
||||
remoteAddr: addr,
|
||||
header: hdr,
|
||||
data: pr[len(pr)-r.Len() : len(pr)-5], // cut off the last 5 bytes
|
||||
})
|
||||
Expect(err.Error()).To(ContainSubstring("Received a Public Reset. An error occurred parsing the packet"))
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user