diff --git a/client_test.go b/client_test.go index 6f8e8d08..f6995efb 100644 --- a/client_test.go +++ b/client_test.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "context" "crypto/tls" "errors" @@ -12,7 +11,6 @@ import ( mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" "github.com/lucas-clemente/quic-go/logging" "github.com/lucas-clemente/quic-go/quictrace" @@ -25,7 +23,7 @@ import ( var _ = Describe("Client", func() { var ( cl *client - packetConn *mockPacketConn + packetConn *MockPacketConn addr net.Addr connID protocol.ConnectionID mockMultiplexer *MockMultiplexer @@ -51,17 +49,6 @@ var _ = Describe("Client", func() { ) quicSession ) - // generate a packet sent by the server that accepts the QUIC version suggested by the client - acceptClientVersionPacket := func(connID protocol.ConnectionID) []byte { - b := &bytes.Buffer{} - Expect((&wire.ExtendedHeader{ - Header: wire.Header{DestConnectionID: connID}, - PacketNumber: 1, - PacketNumberLen: 1, - }).Write(b, protocol.VersionWhatever)).To(Succeed()) - return b.Bytes() - } - BeforeEach(func() { tlsConf = &tls.Config{NextProtos: []string{"proto1"}} connID = protocol.ConnectionID{0, 0, 0, 0, 0, 0, 0x13, 0x37} @@ -73,9 +60,8 @@ var _ = Describe("Client", func() { Eventually(areSessionsRunning).Should(BeFalse()) // sess = NewMockQuicSession(mockCtrl) addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} - packetConn = newMockPacketConn() - packetConn.addr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} - packetConn.dataReadFrom = addr + packetConn = NewMockPacketConn(mockCtrl) + packetConn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() cl = &client{ srcConnID: connID, destConnID: connID, @@ -221,7 +207,7 @@ var _ = Describe("Client", func() { sess.EXPECT().run() return sess } - tracer.EXPECT().StartedConnection(packetConn.addr, addr, protocol.VersionTLS, gomock.Any(), gomock.Any()) + tracer.EXPECT().StartedConnection(packetConn.LocalAddr(), addr, protocol.VersionTLS, gomock.Any(), gomock.Any()) _, err := Dial( packetConn, addr, @@ -350,7 +336,6 @@ var _ = Describe("Client", func() { sess.EXPECT().HandshakeComplete().Return(context.Background()) return sess } - packetConn.dataToRead <- acceptClientVersionPacket(cl.srcConnID) tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), protocol.VersionTLS, gomock.Any(), gomock.Any()) _, err := Dial( packetConn, diff --git a/conn_test.go b/conn_test.go index 1043535e..b4b39ae1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,16 +4,23 @@ import ( "net" "time" + "github.com/lucas-clemente/quic-go/internal/protocol" + + "github.com/golang/mock/gomock" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Basic Conn Test", func() { It("reads a packet", func() { - c := newMockPacketConn() + c := NewMockPacketConn(mockCtrl) addr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234} - c.dataReadFrom = addr - c.dataToRead <- []byte("foobar") + c.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { + data := []byte("foobar") + Expect(b).To(HaveLen(int(protocol.MaxReceivePacketSize))) + return copy(b, data), addr, nil + }) conn, err := wrapConn(c) Expect(err).ToNot(HaveOccurred()) diff --git a/mock_packetconn_test.go b/mock_packetconn_test.go new file mode 100644 index 00000000..e3fe28a8 --- /dev/null +++ b/mock_packetconn_test.go @@ -0,0 +1,137 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: net (interfaces: PacketConn) + +// Package quic is a generated GoMock package. +package quic + +import ( + net "net" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" +) + +// MockPacketConn is a mock of PacketConn interface +type MockPacketConn struct { + ctrl *gomock.Controller + recorder *MockPacketConnMockRecorder +} + +// MockPacketConnMockRecorder is the mock recorder for MockPacketConn +type MockPacketConnMockRecorder struct { + mock *MockPacketConn +} + +// NewMockPacketConn creates a new mock instance +func NewMockPacketConn(ctrl *gomock.Controller) *MockPacketConn { + mock := &MockPacketConn{ctrl: ctrl} + mock.recorder = &MockPacketConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPacketConn) EXPECT() *MockPacketConnMockRecorder { + return m.recorder +} + +// Close mocks base method +func (m *MockPacketConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockPacketConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketConn)(nil).Close)) +} + +// LocalAddr mocks base method +func (m *MockPacketConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr +func (mr *MockPacketConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPacketConn)(nil).LocalAddr)) +} + +// ReadFrom mocks base method +func (m *MockPacketConn) ReadFrom(arg0 []byte) (int, net.Addr, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadFrom", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(net.Addr) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ReadFrom indicates an expected call of ReadFrom +func (mr *MockPacketConnMockRecorder) ReadFrom(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFrom", reflect.TypeOf((*MockPacketConn)(nil).ReadFrom), arg0) +} + +// SetDeadline mocks base method +func (m *MockPacketConn) SetDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *MockPacketConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetDeadline), arg0) +} + +// SetReadDeadline mocks base method +func (m *MockPacketConn) SetReadDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *MockPacketConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetReadDeadline), arg0) +} + +// SetWriteDeadline mocks base method +func (m *MockPacketConn) SetWriteDeadline(arg0 time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline +func (mr *MockPacketConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPacketConn)(nil).SetWriteDeadline), arg0) +} + +// WriteTo mocks base method +func (m *MockPacketConn) WriteTo(arg0 []byte, arg1 net.Addr) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WriteTo", arg0, arg1) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WriteTo indicates an expected call of WriteTo +func (mr *MockPacketConnMockRecorder) WriteTo(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteTo", reflect.TypeOf((*MockPacketConn)(nil).WriteTo), arg0, arg1) +} diff --git a/mockgen.go b/mockgen.go index 89574c89..7f0ba432 100644 --- a/mockgen.go +++ b/mockgen.go @@ -21,3 +21,4 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_manager_test.go github.com/lucas-clemente/quic-go packetHandlerManager" //go:generate sh -c "./mockgen_private.sh quic mock_multiplexer_test.go github.com/lucas-clemente/quic-go multiplexer" //go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_token_store_test.go github.com/lucas-clemente/quic-go TokenStore && goimports -w mock_token_store_test.go" +//go:generate sh -c "mockgen -package quic -self_package github.com/lucas-clemente/quic-go -destination mock_packetconn_test.go net PacketConn && goimports -w mock_packetconn_test.go" diff --git a/multiplexer.go b/multiplexer.go index 129db41f..006305af 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -64,7 +64,8 @@ func (m *connMultiplexer) AddConn( m.mutex.Lock() defer m.mutex.Unlock() - connIndex := c.LocalAddr().Network() + " " + c.LocalAddr().String() + addr := c.LocalAddr() + connIndex := addr.Network() + " " + addr.String() p, ok := m.conns[connIndex] if !ok { manager, err := m.newPacketHandlerManager(c, connIDLen, statelessResetKey, tracer, m.logger) diff --git a/multiplexer_test.go b/multiplexer_test.go index 5faa701e..84332a07 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -3,6 +3,7 @@ package quic import ( "net" + "github.com/golang/mock/gomock" mocklogging "github.com/lucas-clemente/quic-go/internal/mocks/logging" . "github.com/onsi/ginkgo" @@ -14,16 +15,19 @@ type testConn struct { net.PacketConn } -var _ = Describe("Client Multiplexer", func() { +var _ = Describe("Multiplexer", func() { It("adds a new packet conn ", func() { - conn := newMockPacketConn() + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}) _, err := getMultiplexer().AddConn(conn, 8, nil, nil) Expect(err).ToNot(HaveOccurred()) }) It("recognizes when the same connection is added twice", func() { - pconn := newMockPacketConn() - pconn.addr = &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321} + pconn := NewMockPacketConn(mockCtrl) + pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) + pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) conn := testConn{PacketConn: pconn} tracer := mocklogging.NewMockTracer(mockCtrl) _, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) @@ -35,7 +39,9 @@ var _ = Describe("Client Multiplexer", func() { }) It("errors when adding an existing conn with a different connection ID length", func() { - conn := newMockPacketConn() + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) _, err := getMultiplexer().AddConn(conn, 5, nil, nil) Expect(err).ToNot(HaveOccurred()) _, err = getMultiplexer().AddConn(conn, 6, nil, nil) @@ -43,7 +49,9 @@ var _ = Describe("Client Multiplexer", func() { }) It("errors when adding an existing conn with a different stateless rest key", func() { - conn := newMockPacketConn() + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar"), nil) Expect(err).ToNot(HaveOccurred()) _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof"), nil) @@ -51,7 +59,9 @@ var _ = Describe("Client Multiplexer", func() { }) It("errors when adding an existing conn with different tracers", func() { - conn := newMockPacketConn() + conn := NewMockPacketConn(mockCtrl) + conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) _, err := getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) Expect(err).ToNot(HaveOccurred()) _, err = getMultiplexer().AddConn(conn, 7, nil, mocklogging.NewMockTracer(mockCtrl)) diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 13d1d312..dd2895dc 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -20,10 +20,17 @@ import ( ) var _ = Describe("Packet Handler Map", func() { + type packetToRead struct { + addr net.Addr + data []byte + err error + } + var ( - handler *packetHandlerMap - conn *mockPacketConn - tracer *mocklogging.MockTracer + handler *packetHandlerMap + conn *MockPacketConn + tracer *mocklogging.MockTracer + packetChan chan packetToRead connIDLen int statelessResetKey []byte @@ -52,28 +59,24 @@ var _ = Describe("Packet Handler Map", func() { statelessResetKey = nil connIDLen = 0 tracer = mocklogging.NewMockTracer(mockCtrl) + packetChan = make(chan packetToRead, 10) }) JustBeforeEach(func() { - conn = newMockPacketConn() + conn = NewMockPacketConn(mockCtrl) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + conn.EXPECT().ReadFrom(gomock.Any()).DoAndReturn(func(b []byte) (int, net.Addr, error) { + p, ok := <-packetChan + if !ok { + return 0, nil, errors.New("closed") + } + return copy(b, p.data), p.addr, p.err + }).AnyTimes() phm, err := newPacketHandlerMap(conn, connIDLen, statelessResetKey, tracer, utils.DefaultLogger) Expect(err).ToNot(HaveOccurred()) handler = phm.(*packetHandlerMap) }) - AfterEach(func() { - // delete sessions and the server before closing - // They might be mock implementations, and we'd have to register the expected calls before otherwise. - handler.mutex.Lock() - for connID := range handler.handlers { - delete(handler.handlers, connID) - } - handler.server = nil - handler.mutex.Unlock() - handler.Destroy() - Eventually(handler.listening).Should(BeClosed()) - }) - It("closes", func() { getMultiplexer() // make the sync.Once execute // replace the clientMuxer. getClientMultiplexer will now return the MockMultiplexer @@ -94,284 +97,307 @@ var _ = Describe("Packet Handler Map", func() { handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) mockMultiplexer.EXPECT().RemoveConn(gomock.Any()) handler.close(testErr) + close(packetChan) + Eventually(handler.listening).Should(BeClosed()) }) - Context("handling packets", func() { - BeforeEach(func() { - connIDLen = 5 + Context("other operations", func() { + AfterEach(func() { + // delete sessions and the server before closing + // They might be mock implementations, and we'd have to register the expected calls before otherwise. + handler.mutex.Lock() + for connID := range handler.handlers { + delete(handler.handlers, connID) + } + handler.server = nil + handler.mutex.Unlock() + conn.EXPECT().Close().MaxTimes(1) + close(packetChan) + handler.Destroy() + Eventually(handler.listening).Should(BeClosed()) }) - It("handles packets for different packet handlers on the same packet conn", func() { - connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - packetHandler1 := NewMockPacketHandler(mockCtrl) - packetHandler2 := NewMockPacketHandler(mockCtrl) - handledPacket1 := make(chan struct{}) - handledPacket2 := make(chan struct{}) - packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID1)) - close(handledPacket1) + Context("handling packets", func() { + BeforeEach(func() { + connIDLen = 5 }) - packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - connID, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(connID).To(Equal(connID2)) - close(handledPacket2) - }) - handler.Add(connID1, packetHandler1) - handler.Add(connID2, packetHandler2) - conn.dataToRead <- getPacket(connID1) - conn.dataToRead <- getPacket(connID2) - Eventually(handledPacket1).Should(BeClosed()) - Eventually(handledPacket2).Should(BeClosed()) - }) - - It("drops unparseable packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} - tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: []byte{0, 1, 2, 3}, - }) - }) - - It("deletes removed sessions immediately", func() { - handler.deleteRetiredSessionsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - handler.Add(connID, NewMockPacketHandler(mockCtrl)) - handler.Remove(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("deletes retired session entries after a wait time", func() { - handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - sess := NewMockPacketHandler(mockCtrl) - handler.Add(connID, sess) - handler.Retire(connID) - time.Sleep(scaleDuration(30 * time.Millisecond)) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - // don't EXPECT any calls to handlePacket of the MockPacketHandler - }) - - It("passes packets arriving late for closed sessions to that session", func() { - handler.deleteRetiredSessionsAfter = time.Hour - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockPacketHandler(mockCtrl) - handled := make(chan struct{}) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - close(handled) - }) - handler.Add(connID, packetHandler) - handler.Retire(connID) - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - Eventually(handled).Should(BeClosed()) - }) - - It("drops packets for unknown receivers", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - handler.handlePacket(&receivedPacket{data: getPacket(connID)}) - }) - - It("closes the packet handlers when reading from the conn fails", func() { - done := make(chan struct{}) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { - Expect(e).To(HaveOccurred()) - close(done) - }) - handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) - conn.Close() - Eventually(done).Should(BeClosed()) - }) - - It("says if a connection ID is already taken", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) - Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) - }) - - It("says if a connection ID is already taken, for AddWithConnID", func() { - clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - newConnID1 := protocol.ConnectionID{1, 2, 3, 4} - newConnID2 := protocol.ConnectionID{4, 3, 2, 1} - Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) - Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) - }) - }) - - Context("running a server", func() { - It("adds a server", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - cid, err := wire.ParseConnectionID(p.data, 0) - Expect(err).ToNot(HaveOccurred()) - Expect(cid).To(Equal(connID)) - }) - handler.SetServer(server) - handler.handlePacket(&receivedPacket{data: p}) - }) - - It("closes all server sessions", func() { - clientSess := NewMockPacketHandler(mockCtrl) - clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) - serverSess := NewMockPacketHandler(mockCtrl) - serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer) - serverSess.EXPECT().shutdown() - - handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) - handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess) - handler.CloseServer() - }) - - It("stops handling packets with unknown connection IDs after the server is closed", func() { - connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} - p := getPacket(connID) - server := NewMockUnknownPacketHandler(mockCtrl) - // don't EXPECT any calls to server.handlePacket - handler.SetServer(server) - handler.CloseServer() - handler.handlePacket(&receivedPacket{data: p}) - }) - }) - - Context("stateless resets", func() { - BeforeEach(func() { - connIDLen = 5 - }) - - Context("handling", func() { - It("handles stateless resets", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - Expect(err).To(HaveOccurred()) - var resetErr statelessResetErr - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.token).To(Equal(token)) - close(destroyed) + It("handles packets for different packet handlers on the same packet conn", func() { + connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} + packetHandler1 := NewMockPacketHandler(mockCtrl) + packetHandler2 := NewMockPacketHandler(mockCtrl) + handledPacket1 := make(chan struct{}) + handledPacket2 := make(chan struct{}) + packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID1)) + close(handledPacket1) }) - conn.dataToRead <- packet - Eventually(destroyed).Should(BeClosed()) - }) - - It("handles stateless resets for 0-length connection IDs", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) - packet = append(packet, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { - defer GinkgoRecover() - Expect(err).To(HaveOccurred()) - var resetErr statelessResetErr - Expect(errors.As(err, &resetErr)).To(BeTrue()) - Expect(err.Error()).To(ContainSubstring("received a stateless reset")) - Expect(resetErr.token).To(Equal(token)) - close(destroyed) + packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID2)) + close(handledPacket2) }) - conn.dataToRead <- packet - Eventually(destroyed).Should(BeClosed()) + handler.Add(connID1, packetHandler1) + handler.Add(connID2, packetHandler2) + packetChan <- packetToRead{data: getPacket(connID1)} + packetChan <- packetToRead{data: getPacket(connID2)} + + Eventually(handledPacket1).Should(BeClosed()) + Eventually(handledPacket2).Should(BeClosed()) }) - It("retires reset tokens", func() { + It("drops unparseable packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: []byte{0, 1, 2, 3}, + }) + }) + + It("deletes removed sessions immediately", func() { + handler.deleteRetiredSessionsAfter = time.Hour + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + handler.Add(connID, NewMockPacketHandler(mockCtrl)) + handler.Remove(connID) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + // don't EXPECT any calls to handlePacket of the MockPacketHandler + }) + + It("deletes retired session entries after a wait time", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) - connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} - packetHandler := NewMockPacketHandler(mockCtrl) - handler.Add(connID, packetHandler) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) - handler.RetireResetToken(token) - packetHandler.EXPECT().handlePacket(gomock.Any()) - p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) - p = append(p, make([]byte, 50)...) - p = append(p, token[:]...) - + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + sess := NewMockPacketHandler(mockCtrl) + handler.Add(connID, sess) + handler.Retire(connID) time.Sleep(scaleDuration(30 * time.Millisecond)) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + // don't EXPECT any calls to handlePacket of the MockPacketHandler + }) + + It("passes packets arriving late for closed sessions to that session", func() { + handler.deleteRetiredSessionsAfter = time.Hour + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + packetHandler := NewMockPacketHandler(mockCtrl) + handled := make(chan struct{}) + packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + close(handled) + }) + handler.Add(connID, packetHandler) + handler.Retire(connID) + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + Eventually(handled).Should(BeClosed()) + }) + + It("drops packets for unknown receivers", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + handler.handlePacket(&receivedPacket{data: getPacket(connID)}) + }) + + It("closes the packet handlers when reading from the conn fails", func() { + done := make(chan struct{}) + packetHandler := NewMockPacketHandler(mockCtrl) + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(e error) { + Expect(e).To(HaveOccurred()) + close(done) + }) + handler.Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, packetHandler) + packetChan <- packetToRead{err: errors.New("read failed")} + Eventually(done).Should(BeClosed()) + }) + + It("says if a connection ID is already taken", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeTrue()) + Expect(handler.Add(connID, NewMockPacketHandler(mockCtrl))).To(BeFalse()) + }) + + It("says if a connection ID is already taken, for AddWithConnID", func() { + clientDestConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} + newConnID1 := protocol.ConnectionID{1, 2, 3, 4} + newConnID2 := protocol.ConnectionID{4, 3, 2, 1} + Expect(handler.AddWithConnID(clientDestConnID, newConnID1, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeTrue()) + Expect(handler.AddWithConnID(clientDestConnID, newConnID2, func() packetHandler { return NewMockPacketHandler(mockCtrl) })).To(BeFalse()) + }) + }) + + Context("running a server", func() { + It("adds a server", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p := getPacket(connID) + server := NewMockUnknownPacketHandler(mockCtrl) + server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { + cid, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(cid).To(Equal(connID)) + }) + handler.SetServer(server) handler.handlePacket(&receivedPacket{data: p}) }) - It("ignores packets too small to contain a stateless reset", func() { - handler.connIDLen = 0 - packetHandler := NewMockPacketHandler(mockCtrl) - token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddResetToken(token, packetHandler) - packet := append([]byte{0x40} /* short header packet */, token[:15]...) - done := make(chan struct{}) - // don't EXPECT any calls here, but register the closing of the done channel - packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) { - close(done) - }).AnyTimes() - conn.dataToRead <- packet - Consistently(done).ShouldNot(BeClosed()) + It("closes all server sessions", func() { + clientSess := NewMockPacketHandler(mockCtrl) + clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) + serverSess := NewMockPacketHandler(mockCtrl) + serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer) + serverSess.EXPECT().shutdown() + + handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) + handler.Add(protocol.ConnectionID{2, 2, 2, 2}, serverSess) + handler.CloseServer() + }) + + It("stops handling packets with unknown connection IDs after the server is closed", func() { + connID := protocol.ConnectionID{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + p := getPacket(connID) + server := NewMockUnknownPacketHandler(mockCtrl) + // don't EXPECT any calls to server.handlePacket + handler.SetServer(server) + handler.CloseServer() + handler.handlePacket(&receivedPacket{data: p}) }) }) - Context("generating", func() { + Context("stateless resets", func() { BeforeEach(func() { - key := make([]byte, 32) - rand.Read(key) - statelessResetKey = key + connIDLen = 5 }) - It("generates stateless reset tokens", func() { - connID1 := []byte{0xde, 0xad, 0xbe, 0xef} - connID2 := []byte{0xde, 0xca, 0xfb, 0xad} - Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) - }) - - It("sends stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, + Context("handling", func() { + It("handles stateless resets", func() { + packetHandler := NewMockPacketHandler(mockCtrl) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + destroyed := make(chan struct{}) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { + defer GinkgoRecover() + defer close(destroyed) + Expect(err).To(HaveOccurred()) + var resetErr statelessResetErr + Expect(errors.As(err, &resetErr)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("received a stateless reset")) + Expect(resetErr.token).To(Equal(token)) + }) + packetChan <- packetToRead{data: packet} + Eventually(destroyed).Should(BeClosed()) + time.Sleep(time.Second) + }) + + It("handles stateless resets for 0-length connection IDs", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + destroyed := make(chan struct{}) + packet := append([]byte{0x40} /* short header packet */, make([]byte, 50)...) + packet = append(packet, token[:]...) + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(err error) { + defer GinkgoRecover() + Expect(err).To(HaveOccurred()) + var resetErr statelessResetErr + Expect(errors.As(err, &resetErr)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("received a stateless reset")) + Expect(resetErr.token).To(Equal(token)) + close(destroyed) + }) + packetChan <- packetToRead{data: packet} + Eventually(destroyed).Should(BeClosed()) + }) + + It("retires reset tokens", func() { + handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) + connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} + packetHandler := NewMockPacketHandler(mockCtrl) + handler.Add(connID, packetHandler) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, NewMockPacketHandler(mockCtrl)) + handler.RetireResetToken(token) + packetHandler.EXPECT().handlePacket(gomock.Any()) + p := append([]byte{0x40} /* short header packet */, connID.Bytes()...) + p = append(p, make([]byte, 50)...) + p = append(p, token[:]...) + + time.Sleep(scaleDuration(30 * time.Millisecond)) + handler.handlePacket(&receivedPacket{data: p}) + }) + + It("ignores packets too small to contain a stateless reset", func() { + handler.connIDLen = 0 + packetHandler := NewMockPacketHandler(mockCtrl) + token := protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + handler.AddResetToken(token, packetHandler) + done := make(chan struct{}) + // don't EXPECT any calls here, but register the closing of the done channel + packetHandler.EXPECT().destroy(gomock.Any()).Do(func(error) { + close(done) + }).AnyTimes() + packetChan <- packetToRead{data: append([]byte{0x40} /* short header packet */, token[:15]...)} + Consistently(done).ShouldNot(BeClosed()) }) - var reset mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&reset)) - Expect(reset.to).To(Equal(addr)) - Expect(reset.data[0] & 0x80).To(BeZero()) // short header packet - Expect(reset.data).To(HaveLen(protocol.MinStatelessResetSize)) }) - It("doesn't send stateless resets for small packets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, + Context("generating", func() { + BeforeEach(func() { + key := make([]byte, 32) + rand.Read(key) + statelessResetKey = key }) - Consistently(conn.dataWritten).ShouldNot(Receive()) - }) - }) - Context("if no key is configured", func() { - It("doesn't send stateless resets", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} - p := append([]byte{40}, make([]byte, 100)...) - handler.handlePacket(&receivedPacket{ - buffer: getPacketBuffer(), - remoteAddr: addr, - data: p, + It("generates stateless reset tokens", func() { + connID1 := []byte{0xde, 0xad, 0xbe, 0xef} + connID2 := []byte{0xde, 0xca, 0xfb, 0xad} + Expect(handler.GetStatelessResetToken(connID1)).ToNot(Equal(handler.GetStatelessResetToken(connID2))) + }) + + It("sends stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), addr).Do(func(b []byte, _ net.Addr) { + defer close(done) + Expect(b[0] & 0x80).To(BeZero()) // short header packet + Expect(b).To(HaveLen(protocol.MinStatelessResetSize)) + }) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) + Eventually(done).Should(BeClosed()) + }) + + It("doesn't send stateless resets for small packets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, protocol.MinStatelessResetSize-2)...) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) + }) + }) + + Context("if no key is configured", func() { + It("doesn't send stateless resets", func() { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337} + p := append([]byte{40}, make([]byte, 100)...) + handler.handlePacket(&receivedPacket{ + buffer: getPacketBuffer(), + remoteAddr: addr, + data: p, + }) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) }) - Consistently(conn.dataWritten).ShouldNot(Receive()) }) }) }) diff --git a/send_conn_test.go b/send_conn_test.go index 570b7e91..5100963c 100644 --- a/send_conn_test.go +++ b/send_conn_test.go @@ -1,90 +1,28 @@ package quic import ( - "errors" "net" - "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) -type mockPacketConnWrite struct { - data []byte - to net.Addr -} - -type mockPacketConn struct { - addr net.Addr - dataToRead chan []byte - dataReadFrom net.Addr - readErr error - dataWritten chan mockPacketConnWrite - closed bool -} - -func newMockPacketConn() *mockPacketConn { - return &mockPacketConn{ - addr: &net.UDPAddr{IP: net.IPv6zero, Port: 0x42}, - dataToRead: make(chan []byte, 1000), - dataWritten: make(chan mockPacketConnWrite, 1000), - } -} - -func (c *mockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - if c.readErr != nil { - return 0, nil, c.readErr - } - data, ok := <-c.dataToRead - if !ok { - return 0, nil, errors.New("connection closed") - } - n := copy(b, data) - return n, c.dataReadFrom, nil -} - -func (c *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - select { - case c.dataWritten <- mockPacketConnWrite{to: addr, data: b}: - return len(b), nil - default: - panic("channel full") - } -} - -func (c *mockPacketConn) Close() error { - if !c.closed { - close(c.dataToRead) - } - c.closed = true - return nil -} -func (c *mockPacketConn) LocalAddr() net.Addr { return c.addr } -func (c *mockPacketConn) SetDeadline(t time.Time) error { panic("not implemented") } -func (c *mockPacketConn) SetReadDeadline(t time.Time) error { panic("not implemented") } -func (c *mockPacketConn) SetWriteDeadline(t time.Time) error { panic("not implemented") } - -var _ net.PacketConn = &mockPacketConn{} - -var _ = Describe("Send-Connection", func() { - var c sendConn - var packetConn *mockPacketConn +var _ = Describe("Connection (for sending packets)", func() { + var ( + c sendConn + packetConn *MockPacketConn + addr net.Addr + ) BeforeEach(func() { - addr := &net.UDPAddr{ - IP: net.IPv4(192, 168, 100, 200), - Port: 1337, - } - packetConn = newMockPacketConn() + addr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337} + packetConn = NewMockPacketConn(mockCtrl) c = newSendConn(packetConn, addr) }) It("writes", func() { + packetConn.EXPECT().WriteTo([]byte("foobar"), addr) Expect(c.Write([]byte("foobar"))).To(Succeed()) - var write mockPacketConnWrite - Expect(packetConn.dataWritten).To(Receive(&write)) - Expect(write.to.String()).To(Equal("192.168.100.200:1337")) - Expect(write.data).To(Equal([]byte("foobar"))) }) It("gets the remote address", func() { @@ -96,13 +34,12 @@ var _ = Describe("Send-Connection", func() { IP: net.IPv4(192, 168, 0, 1), Port: 1234, } - packetConn.addr = addr + packetConn.EXPECT().LocalAddr().Return(addr) Expect(c.LocalAddr()).To(Equal(addr)) }) It("closes", func() { - err := c.Close() - Expect(err).ToNot(HaveOccurred()) - Expect(packetConn.closed).To(BeTrue()) + packetConn.EXPECT().Close() + Expect(c.Close()).To(Succeed()) }) }) diff --git a/server_test.go b/server_test.go index a2a9278a..d0c08cf7 100644 --- a/server_test.go +++ b/server_test.go @@ -38,7 +38,7 @@ func areServersRunning() bool { var _ = Describe("Server", func() { var ( - conn *mockPacketConn + conn *MockPacketConn tlsConf *tls.Config ) @@ -97,8 +97,9 @@ var _ = Describe("Server", func() { } BeforeEach(func() { - conn = newMockPacketConn() - conn.addr = &net.UDPAddr{} + conn = NewMockPacketConn(mockCtrl) + conn.EXPECT().LocalAddr().Return(&net.UDPAddr{}).AnyTimes() + conn.EXPECT().ReadFrom(gomock.Any()).Do(func(_ []byte) { <-(make(chan struct{})) }).MaxTimes(1) tlsConf = testdata.GetTLSConfig() tlsConf.NextProtos = []string{"proto1"} }) @@ -212,7 +213,8 @@ var _ = Describe("Server", func() { }, nil) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) serv.handlePacket(p) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) }) It("drops too small Initial", func() { @@ -225,7 +227,8 @@ var _ = Describe("Server", func() { ) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropUnexpectedPacket) serv.handlePacket(p) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) }) It("drops non-Initial packets", func() { @@ -236,7 +239,8 @@ var _ = Describe("Server", func() { }, []byte("invalid")) tracer.EXPECT().DroppedPacket(p.remoteAddr, logging.PacketTypeHandshake, p.Size(), logging.PacketDropUnexpectedPacket) serv.handlePacket(p) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) }) It("decodes the token from the Token field", func() { @@ -260,6 +264,7 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize)) packet.remoteAddr = raddr + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) @@ -284,6 +289,7 @@ var _ = Describe("Server", func() { Version: serv.config.Versions[0], }, make([]byte, protocol.MinInitialPacketSize)) packet.remoteAddr = raddr + conn.EXPECT().WriteTo(gomock.Any(), gomock.Any()).MaxTimes(1) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).MaxTimes(1) serv.handlePacket(packet) Eventually(done).Should(BeClosed()) @@ -360,8 +366,9 @@ var _ = Describe("Server", func() { go func() { defer GinkgoRecover() serv.handlePacket(p) - // the Handshake packet is written by the session - Consistently(conn.dataWritten).ShouldNot(Receive()) + // the Handshake packet is written by the session. + // Make sure there are no Write calls on the packet conn. + time.Sleep(50 * time.Millisecond) close(done) }() // make sure we're using a server-generated connection ID @@ -379,23 +386,27 @@ var _ = Describe("Server", func() { DestConnectionID: destConnID, Version: 0x42, }, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { Expect(replyHdr.IsLongHeader).To(BeTrue()) Expect(replyHdr.Version).To(BeZero()) Expect(replyHdr.SrcConnectionID).To(Equal(destConnID)) Expect(replyHdr.DestConnectionID).To(Equal(srcConnID)) }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + Expect(wire.IsVersionNegotiationPacket(b)).To(BeTrue()) + hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(b)) + Expect(err).ToNot(HaveOccurred()) + Expect(hdr.DestConnectionID).To(Equal(srcConnID)) + Expect(hdr.SrcConnectionID).To(Equal(destConnID)) + Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) + return len(b), nil + }) serv.handlePacket(packet) - var write mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&write)) - Expect(write.to.String()).To(Equal("127.0.0.1:1337")) - Expect(wire.IsVersionNegotiationPacket(write.data)).To(BeTrue()) - hdr, versions, err := wire.ParseVersionNegotiationPacket(bytes.NewReader(write.data)) - Expect(err).ToNot(HaveOccurred()) - Expect(hdr.DestConnectionID).To(Equal(srcConnID)) - Expect(hdr.SrcConnectionID).To(Equal(destConnID)) - Expect(versions).ToNot(ContainElement(protocol.VersionNumber(0x42))) + Eventually(done).Should(BeClosed()) }) It("replies with a Retry packet, if a Token is required", func() { @@ -408,23 +419,27 @@ var _ = Describe("Server", func() { Version: protocol.VersionTLS, } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), nil).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, _ []logging.Frame) { Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) Expect(replyHdr.Token).ToNot(BeEmpty()) }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) + Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(replyHdr.Token).ToNot(BeEmpty()) + Expect(b[len(b)-16:]).To(Equal(handshake.GetRetryIntegrityTag(b[:len(b)-16], hdr.DestConnectionID)[:])) + return len(b), nil + }) serv.handlePacket(packet) - var write mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&write)) - Expect(write.to.String()).To(Equal("127.0.0.1:1337")) - replyHdr := parseHeader(write.data) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry)) - Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(replyHdr.Token).ToNot(BeEmpty()) - Expect(write.data[len(write.data)-16:]).To(Equal(handshake.GetRetryIntegrityTag(write.data[:len(write.data)-16], hdr.DestConnectionID)[:])) + Eventually(done).Should(BeClosed()) }) It("sends an INVALID_TOKEN error, if an invalid retry token is received", func() { @@ -441,7 +456,8 @@ var _ = Describe("Server", func() { } packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) packet.data = append(packet.data, []byte("coalesced packet")...) // add some garbage to simulate a coalesced packet - packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + raddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + packet.remoteAddr = raddr tracer.EXPECT().SentPacket(packet.remoteAddr, gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ net.Addr, replyHdr *logging.Header, _ logging.ByteCount, frames []logging.Frame) { Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) @@ -452,25 +468,28 @@ var _ = Describe("Server", func() { Expect(ccf.IsApplicationError).To(BeFalse()) Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken)) }) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), raddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + replyHdr := parseHeader(b) + Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient) + extHdr, err := unpackHeader(opener, replyHdr, b, hdr.Version) + Expect(err).ToNot(HaveOccurred()) + data, err := opener.Open(nil, b[extHdr.ParsedLen():], extHdr.PacketNumber, b[:extHdr.ParsedLen()]) + Expect(err).ToNot(HaveOccurred()) + f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) + Expect(err).ToNot(HaveOccurred()) + Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) + ccf := f.(*wire.ConnectionCloseFrame) + Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken)) + Expect(ccf.ReasonPhrase).To(BeEmpty()) + return len(b), nil + }) serv.handlePacket(packet) - var write mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&write)) - Expect(write.to.String()).To(Equal("127.0.0.1:1337")) - replyHdr := parseHeader(write.data) - Expect(replyHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) - Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - _, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveClient) - extHdr, err := unpackHeader(opener, replyHdr, write.data, hdr.Version) - Expect(err).ToNot(HaveOccurred()) - data, err := opener.Open(nil, write.data[extHdr.ParsedLen():], extHdr.PacketNumber, write.data[:extHdr.ParsedLen()]) - Expect(err).ToNot(HaveOccurred()) - f, err := wire.NewFrameParser(hdr.Version).ParseNext(bytes.NewReader(data), protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(f).To(BeAssignableToTypeOf(&wire.ConnectionCloseFrame{})) - ccf := f.(*wire.ConnectionCloseFrame) - Expect(ccf.ErrorCode).To(Equal(qerr.InvalidToken)) - Expect(ccf.ReasonPhrase).To(BeEmpty()) + Eventually(done).Should(BeClosed()) }) It("doesn't send an INVALID_TOKEN error, if the packet is corrupted", func() { @@ -490,7 +509,8 @@ var _ = Describe("Server", func() { packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} tracer.EXPECT().DroppedPacket(packet.remoteAddr, logging.PacketTypeInitial, packet.Size(), logging.PacketDropPayloadDecryptError) serv.handlePacket(packet) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) }) It("creates a session, if no Token is required", func() { @@ -559,7 +579,8 @@ var _ = Describe("Server", func() { defer GinkgoRecover() serv.handlePacket(p) // the Handshake packet is written by the session - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) close(done) }() // make sure we're using a server-generated connection ID @@ -756,7 +777,8 @@ var _ = Describe("Server", func() { defer GinkgoRecover() defer wg.Done() serv.handlePacket(getInitialWithRandomDestConnID()) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) }() } wg.Wait() @@ -764,15 +786,18 @@ var _ = Describe("Server", func() { hdr, _, _, err := wire.ParsePacket(p.data, 0) Expect(err).ToNot(HaveOccurred()) tracer.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), p.remoteAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + rejectHdr := parseHeader(b) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(rejectHdr.Version).To(Equal(hdr.Version)) + Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + return len(b), nil + }) serv.handlePacket(p) - var reject mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&reject)) - Expect(reject.to).To(Equal(p.remoteAddr)) - rejectHdr := parseHeader(reject.data) - Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(rejectHdr.Version).To(Equal(hdr.Version)) - Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Eventually(done).Should(BeClosed()) }) It("doesn't accept new sessions if they were closed in the mean time", func() { @@ -817,7 +842,8 @@ var _ = Describe("Server", func() { tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()) serv.handlePacket(p) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) Eventually(sessionCreated).Should(BeClosed()) cancel() time.Sleep(scaleDuration(200 * time.Millisecond)) @@ -1034,19 +1060,23 @@ var _ = Describe("Server", func() { } Eventually(func() int32 { return atomic.LoadInt32(&serv.sessionQueueLen) }).Should(BeEquivalentTo(protocol.MaxAcceptQueueSize)) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) p := getInitialWithRandomDestConnID() hdr := parseHeader(p.data) + done := make(chan struct{}) + conn.EXPECT().WriteTo(gomock.Any(), senderAddr).DoAndReturn(func(b []byte, _ net.Addr) (int, error) { + defer close(done) + rejectHdr := parseHeader(b) + Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) + Expect(rejectHdr.Version).To(Equal(hdr.Version)) + Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) + Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + return len(b), nil + }) serv.handlePacket(p) - var reject mockPacketConnWrite - Eventually(conn.dataWritten).Should(Receive(&reject)) - Expect(reject.to).To(Equal(senderAddr)) - rejectHdr := parseHeader(reject.data) - Expect(rejectHdr.Type).To(Equal(protocol.PacketTypeInitial)) - Expect(rejectHdr.Version).To(Equal(hdr.Version)) - Expect(rejectHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID)) - Expect(rejectHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID)) + Eventually(done).Should(BeClosed()) }) It("doesn't accept new sessions if they were closed in the mean time", func() { @@ -1087,7 +1117,8 @@ var _ = Describe("Server", func() { return true }) serv.handlePacket(p) - Consistently(conn.dataWritten).ShouldNot(Receive()) + // make sure there are no Write calls on the packet conn + time.Sleep(50 * time.Millisecond) Eventually(sessionCreated).Should(BeClosed()) cancel() time.Sleep(scaleDuration(200 * time.Millisecond))