diff --git a/benchmark_test.go b/benchmark_test.go index b74d26eb..4539236d 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -24,11 +24,11 @@ import ( ) type linkedConnection struct { - other *Session + other *session c chan []byte } -func newLinkedConnection(other *Session) *linkedConnection { +func newLinkedConnection(other *session) *linkedConnection { c := make(chan []byte, 500) conn := &linkedConnection{ c: c, @@ -98,18 +98,18 @@ var _ = Describe("Benchmarks", func() { connID := protocol.ConnectionID(mrand.Uint32()) c1 := newLinkedConnection(nil) - session1I, err := newSession(c1, version, connID, nil, func(*Session, utils.Stream) {}, func(id protocol.ConnectionID) {}) + session1I, err := newSession(c1, version, connID, nil, func(Session, utils.Stream) {}, func(id protocol.ConnectionID) {}) if err != nil { Expect(err).NotTo(HaveOccurred()) } - session1 := session1I.(*Session) + session1 := session1I.(*session) c2 := newLinkedConnection(session1) - session2I, err := newSession(c2, version, connID, nil, func(*Session, utils.Stream) {}, func(id protocol.ConnectionID) {}) + session2I, err := newSession(c2, version, connID, nil, func(Session, utils.Stream) {}, func(id protocol.ConnectionID) {}) if err != nil { Expect(err).NotTo(HaveOccurred()) } - session2 := session2I.(*Session) + session2 := session2I.(*session) c1.other = session2 key := make([]byte, 16) diff --git a/client.go b/client.go index 100f9b61..43af0343 100644 --- a/client.go +++ b/client.go @@ -217,7 +217,7 @@ func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) e return nil } -func (c *Client) streamCallback(session *Session, stream utils.Stream) {} +func (c *Client) streamCallback(Session, utils.Stream) {} func (c *Client) closeCallback(id protocol.ConnectionID) { utils.Infof("Connection %x closed.", id) diff --git a/client_test.go b/client_test.go index c75773db..4e4088dc 100644 --- a/client_test.go +++ b/client_test.go @@ -20,7 +20,7 @@ import ( var _ = Describe("Client", func() { var ( client *Client - session *mockSession + sess *mockSession versionNegotiateCallbackCalled bool ) @@ -32,9 +32,9 @@ var _ = Describe("Client", func() { return nil }, } - session = &mockSession{connectionID: 0x1337} + sess = &mockSession{connectionID: 0x1337} client.connectionID = 0x1337 - client.session = session + client.session = sess client.version = protocol.Version36 }) @@ -51,7 +51,7 @@ var _ = Describe("Client", func() { client, err = NewClient("quic.clemente.io:1337", nil, nil, nil) Expect(err).ToNot(HaveOccurred()) Expect(client.hostname).To(Equal("quic.clemente.io")) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*Session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(BeNil()) }) It("errors on invalid public header", func() { @@ -78,8 +78,8 @@ var _ = Describe("Client", func() { testErr := errors.New("test error") err := client.Close(testErr) Expect(err).ToNot(HaveOccurred()) - Eventually(session.closed).Should(BeTrue()) - Expect(session.closeReason).To(MatchError(testErr)) + Eventually(sess.closed).Should(BeTrue()) + Expect(sess.closeReason).To(MatchError(testErr)) Expect(client.closed).To(Equal(uint32(1))) Eventually(func() bool { return stoppedListening }).Should(BeTrue()) Eventually(runtime.NumGoroutine()).Should(Equal(numGoRoutines)) @@ -90,8 +90,8 @@ var _ = Describe("Client", func() { client.closed = 1 err := client.Close(errors.New("test error")) Expect(err).ToNot(HaveOccurred()) - Eventually(session.closed).Should(BeFalse()) - Expect(session.closeReason).ToNot(HaveOccurred()) + Eventually(sess.closed).Should(BeFalse()) + Expect(sess.closeReason).ToNot(HaveOccurred()) }) It("creates new sessions with the right parameters", func() { @@ -101,8 +101,8 @@ var _ = Describe("Client", func() { err := client.createNewSession(nil) Expect(err).ToNot(HaveOccurred()) Expect(client.session).ToNot(BeNil()) - Expect(client.session.(*Session).connectionID).To(Equal(client.connectionID)) - Expect(client.session.(*Session).version).To(Equal(client.version)) + Expect(client.session.(*session).connectionID).To(Equal(client.connectionID)) + Expect(client.session.(*session).version).To(Equal(client.version)) err = client.Close(nil) Expect(err).ToNot(HaveOccurred()) @@ -133,7 +133,7 @@ var _ = Describe("Client", func() { stoppedListening = true }() - Expect(session.packetCount).To(BeZero()) + Expect(sess.packetCount).To(BeZero()) ph := PublicHeader{ PacketNumber: 1, PacketNumberLen: protocol.PacketNumberLen2, @@ -145,8 +145,8 @@ var _ = Describe("Client", func() { _, err = serverConn.Write(b.Bytes()) Expect(err).ToNot(HaveOccurred()) - Eventually(func() int { return session.packetCount }).Should(Equal(1)) - Expect(session.closed).To(BeFalse()) + Eventually(func() int { return sess.packetCount }).Should(Equal(1)) + Expect(sess.closed).To(BeFalse()) Eventually(func() bool { return stoppedListening }).Should(BeFalse()) err = client.Close(nil) @@ -168,8 +168,8 @@ var _ = Describe("Client", func() { listenErr = client.Listen() Expect(listenErr).To(HaveOccurred()) - Eventually(session.closed).Should(BeTrue()) - Expect(session.closeReason).To(MatchError(listenErr)) + Eventually(sess.closed).Should(BeTrue()) + Expect(sess.closeReason).To(MatchError(listenErr)) close(done) }) }) @@ -210,19 +210,19 @@ var _ = Describe("Client", func() { startUDPConn() newVersion := protocol.Version35 Expect(newVersion).ToNot(Equal(client.version)) - Expect(session.packetCount).To(BeZero()) + Expect(sess.packetCount).To(BeZero()) client.connectionID = 0x1337 err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{newVersion})) Expect(client.version).To(Equal(newVersion)) Expect(client.versionNegotiated).To(BeTrue()) Expect(versionNegotiateCallbackCalled).To(BeTrue()) // it swapped the sessions - Expect(client.session).ToNot(Equal(session)) + Expect(client.session).ToNot(Equal(sess)) Expect(client.connectionID).ToNot(Equal(0x1337)) // it generated a new connection ID Expect(err).ToNot(HaveOccurred()) // it didn't pass the version negoation packet to the session (since it has no payload) - Expect(session.packetCount).To(BeZero()) - Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*Session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) + Expect(sess.packetCount).To(BeZero()) + Expect(*(*[]protocol.VersionNumber)(unsafe.Pointer(reflect.ValueOf(client.session.(*session).cryptoSetup).Elem().FieldByName("negotiatedVersions").UnsafeAddr()))).To(Equal([]protocol.VersionNumber{35})) err = client.Close(nil) Expect(err).ToNot(HaveOccurred()) @@ -236,11 +236,11 @@ var _ = Describe("Client", func() { It("ignores delayed version negotiation packets", func() { // if the version was not yet negotiated, handlePacket would return a VersionNegotiationMismatch error, see above test client.versionNegotiated = true - Expect(session.packetCount).To(BeZero()) + Expect(sess.packetCount).To(BeZero()) err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{1})) Expect(err).ToNot(HaveOccurred()) Expect(client.versionNegotiated).To(BeTrue()) - Expect(session.packetCount).To(BeZero()) + Expect(sess.packetCount).To(BeZero()) Expect(versionNegotiateCallbackCalled).To(BeFalse()) }) diff --git a/h2quic/server.go b/h2quic/server.go index 5ec2dd73..cbb1eb86 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -90,7 +90,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { return server.Serve(conn) } -func (s *Server) handleStreamCb(session *quic.Session, stream utils.Stream) { +func (s *Server) handleStreamCb(session quic.Session, stream utils.Stream) { s.handleStream(session, stream) } diff --git a/interface.go b/interface.go new file mode 100644 index 00000000..918934c9 --- /dev/null +++ b/interface.go @@ -0,0 +1,25 @@ +package quic + +import ( + "net" + + "github.com/lucas-clemente/quic-go/protocol" + "github.com/lucas-clemente/quic-go/utils" +) + +// A Session is a QUIC Session +type Session interface { + // get the next stream opened by the client + // first stream returned has StreamID 3 + AcceptStream() (utils.Stream, error) + // guaranteed to return the smallest unopened stream + // special error for "too many streams, retry later" + OpenStream() (utils.Stream, error) + // TODO: implement this + // blocks until a new stream can be opened, if the maximum number of stream is opened + // OpenStreamSync() (utils.Stream, error) + RemoteAddr() net.Addr + Close(error) error + // TODO: remove this + GetOrOpenStream(protocol.StreamID) (utils.Stream, error) +} diff --git a/session.go b/session.go index 3d750695..817d673c 100644 --- a/session.go +++ b/session.go @@ -32,11 +32,11 @@ type receivedPacket struct { var ( errRstStreamOnInvalidStream = errors.New("RST_STREAM received for unknown stream") errWindowUpdateOnClosedStream = errors.New("WINDOW_UPDATE received for an already closed stream") - errSessionAlreadyClosed = errors.New("Cannot close Session. It was already closed before.") + errSessionAlreadyClosed = errors.New("Cannot close session. It was already closed before.") ) // StreamCallback gets a stream frame and returns a reply frame -type StreamCallback func(*Session, utils.Stream) +type StreamCallback func(Session, utils.Stream) // CryptoChangeCallback is called every time the encryption level changes // Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that @@ -46,7 +46,7 @@ type CryptoChangeCallback func(isForwardSecure bool) type closeCallback func(id protocol.ConnectionID) // A Session is a QUIC session -type Session struct { +type session struct { connectionID protocol.ConnectionID perspective protocol.Perspective version protocol.VersionNumber @@ -100,9 +100,11 @@ type Session struct { timerRead bool } +var _ Session = &session{} + // newSession makes a new session func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) { - session := &Session{ + s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveServer, @@ -114,8 +116,8 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), } - session.setup() - cryptoStream, _ := session.GetOrOpenStream(1) + s.setup() + cryptoStream, _ := s.GetOrOpenStream(1) var sourceAddr []byte if udpAddr, ok := conn.RemoteAddr().(*net.UDPAddr); ok { sourceAddr = udpAddr.IP @@ -123,19 +125,19 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol sourceAddr = []byte(conn.RemoteAddr().String()) } var err error - session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, session.connectionParameters, session.aeadChanged) + s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, s.aeadChanged) if err != nil { return nil, err } - session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParameters, session.streamFramer, session.perspective, session.version) - session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: session.version} + s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) + s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - return session, err + return s, err } -func newClientSession(pconn net.PacketConn, addr net.Addr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*Session, error) { - session := &Session{ +func newClientSession(pconn net.PacketConn, addr net.Addr, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, streamCallback StreamCallback, closeCallback closeCallback, cryptoChangeCallback CryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) { + s := &session{ conn: &conn{pconn: pconn, currentAddr: addr}, connectionID: connectionID, perspective: protocol.PerspectiveClient, @@ -147,24 +149,24 @@ func newClientSession(pconn net.PacketConn, addr net.Addr, hostname string, v pr connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), } - session.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(session.ackAlarmChanged) - session.setup() + s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.ackAlarmChanged) + s.setup() - cryptoStream, _ := session.OpenStream() + cryptoStream, _ := s.OpenStream() var err error - session.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, session.connectionParameters, session.aeadChanged, negotiatedVersions) + s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, s.connectionParameters, s.aeadChanged, negotiatedVersions) if err != nil { return nil, err } - session.packer = newPacketPacker(connectionID, session.cryptoSetup, session.connectionParameters, session.streamFramer, session.perspective, session.version) - session.unpacker = &packetUnpacker{aead: session.cryptoSetup, version: session.version} + s.packer = newPacketPacker(connectionID, s.cryptoSetup, s.connectionParameters, s.streamFramer, s.perspective, s.version) + s.unpacker = &packetUnpacker{aead: s.cryptoSetup, version: s.version} - return session, err + return s, err } // setup is called from newSession and newClientSession and initializes values that are independent of the perspective -func (s *Session) setup() { +func (s *session) setup() { s.rttStats = &congestion.RTTStats{} flowControlManager := flowcontrol.NewFlowControlManager(s.connectionParameters, s.rttStats) @@ -193,7 +195,7 @@ func (s *Session) setup() { } // run the session main loop -func (s *Session) run() { +func (s *session) run() { // Start the crypto stream handler go func() { if err := s.cryptoSetup.HandleCryptoStream(); err != nil { @@ -263,7 +265,7 @@ runLoop: s.runClosed <- struct{}{} } -func (s *Session) maybeResetTimer() { +func (s *session) maybeResetTimer() { nextDeadline := s.lastNetworkActivityTime.Add(s.idleTimeout()) if !s.nextAckScheduledTime.IsZero() { @@ -293,14 +295,14 @@ func (s *Session) maybeResetTimer() { s.currentDeadline = nextDeadline } -func (s *Session) idleTimeout() time.Duration { +func (s *session) idleTimeout() time.Duration { if s.cryptoSetup.HandshakeComplete() { return s.connectionParameters.GetIdleConnectionStateLifetime() } return protocol.InitialIdleTimeout } -func (s *Session) handlePacketImpl(p *receivedPacket) error { +func (s *session) handlePacketImpl(p *receivedPacket) error { if s.perspective == protocol.PerspectiveClient { diversificationNonce := p.publicHeader.DiversificationNonce if len(diversificationNonce) > 0 { @@ -364,7 +366,7 @@ func (s *Session) handlePacketImpl(p *receivedPacket) error { return s.handleFrames(packet.frames) } -func (s *Session) handleFrames(fs []frames.Frame) error { +func (s *session) handleFrames(fs []frames.Frame) error { for _, ff := range fs { var err error frames.LogFrame(ff, false) @@ -407,7 +409,7 @@ func (s *Session) handleFrames(fs []frames.Frame) error { } // handlePacket is called by the server with a new packet -func (s *Session) handlePacket(p *receivedPacket) { +func (s *session) handlePacket(p *receivedPacket) { // Discard packets once the amount of queued packets is larger than // the channel size, protocol.MaxSessionUnprocessedPackets select { @@ -416,7 +418,7 @@ func (s *Session) handlePacket(p *receivedPacket) { } } -func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { +func (s *session) handleStreamFrame(frame *frames.StreamFrame) error { str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { return err @@ -433,7 +435,7 @@ func (s *Session) handleStreamFrame(frame *frames.StreamFrame) error { return nil } -func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { +func (s *session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error { if frame.StreamID != 0 { str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { @@ -447,7 +449,7 @@ func (s *Session) handleWindowUpdateFrame(frame *frames.WindowUpdateFrame) error return err } -func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { +func (s *session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { str, err := s.streamsMap.GetOrOpenStream(frame.StreamID) if err != nil { return err @@ -460,7 +462,7 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { return s.flowControlManager.ResetStream(frame.StreamID, frame.ByteOffset) } -func (s *Session) handleAckFrame(frame *frames.AckFrame) error { +func (s *session) handleAckFrame(frame *frames.AckFrame) error { if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, s.lastNetworkActivityTime); err != nil { return err } @@ -469,7 +471,7 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error { // Close the connection. If err is nil it will be set to qerr.PeerGoingAway. // It waits until the run loop has stopped before returning -func (s *Session) Close(e error) error { +func (s *session) Close(e error) error { err := s.closeImpl(e, false) if err == errSessionAlreadyClosed { return nil @@ -481,7 +483,7 @@ func (s *Session) Close(e error) error { } // close the connection. Use this when called from the run loop -func (s *Session) close(e error) error { +func (s *session) close(e error) error { err := s.closeImpl(e, false) if err == errSessionAlreadyClosed { return nil @@ -489,7 +491,7 @@ func (s *Session) close(e error) error { return err } -func (s *Session) closeImpl(e error, remoteClose bool) error { +func (s *session) closeImpl(e error, remoteClose bool) error { // Only close once if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) { return errSessionAlreadyClosed @@ -536,14 +538,14 @@ func (s *Session) closeImpl(e error, remoteClose bool) error { return nil } -func (s *Session) closeStreamsWithError(err error) { +func (s *session) closeStreamsWithError(err error) { s.streamsMap.Iterate(func(str *stream) (bool, error) { str.Cancel(err) return true, nil }) } -func (s *Session) sendPacket() error { +func (s *session) sendPacket() error { // Repeatedly try sending until we don't have any more data, or run out of the congestion window for { err := s.sentPacketHandler.CheckForError() @@ -638,7 +640,7 @@ func (s *Session) sendPacket() error { } } -func (s *Session) sendConnectionClose(quicErr *qerr.QuicError) error { +func (s *session) sendConnectionClose(quicErr *qerr.QuicError) error { packet, err := s.packer.PackConnectionClose(&frames.ConnectionCloseFrame{ErrorCode: quicErr.ErrorCode, ReasonPhrase: quicErr.ErrorMessage}, s.sentPacketHandler.GetLeastUnacked()) if err != nil { return err @@ -650,7 +652,7 @@ func (s *Session) sendConnectionClose(quicErr *qerr.QuicError) error { return s.conn.write(packet.raw) } -func (s *Session) logPacket(packet *packedPacket) { +func (s *session) logPacket(packet *packedPacket) { if !utils.Debug() { // We don't need to allocate the slices for calling the format functions return @@ -665,7 +667,7 @@ func (s *Session) logPacket(packet *packedPacket) { // GetOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed. // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used. -func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { +func (s *session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { str, err := s.streamsMap.GetOrOpenStream(id) if str != nil { return str, err @@ -675,16 +677,16 @@ func (s *Session) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error) { } // AcceptStream returns the next stream openend by the peer -func (s *Session) AcceptStream() (utils.Stream, error) { +func (s *session) AcceptStream() (utils.Stream, error) { return s.streamsMap.AcceptStream() } // OpenStream opens a stream -func (s *Session) OpenStream() (utils.Stream, error) { +func (s *session) OpenStream() (utils.Stream, error) { return s.streamsMap.OpenStream() } -func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { +func (s *session) queueResetStreamFrame(id protocol.StreamID, offset protocol.ByteCount) { s.packer.QueueControlFrameForNextPacket(&frames.RstStreamFrame{ StreamID: id, ByteOffset: offset, @@ -692,7 +694,7 @@ func (s *Session) queueResetStreamFrame(id protocol.StreamID, offset protocol.By s.scheduleSending() } -func (s *Session) newStream(id protocol.StreamID) (*stream, error) { +func (s *session) newStream(id protocol.StreamID) (*stream, error) { stream, err := newStream(id, s.scheduleSending, s.queueResetStreamFrame, s.flowControlManager) if err != nil { return nil, err @@ -712,7 +714,7 @@ func (s *Session) newStream(id protocol.StreamID) (*stream, error) { // garbageCollectStreams goes through all streams and removes EOF'ed streams // from the streams map. -func (s *Session) garbageCollectStreams() { +func (s *session) garbageCollectStreams() { s.streamsMap.Iterate(func(str *stream) (bool, error) { id := str.StreamID() if str.finished() { @@ -726,20 +728,20 @@ func (s *Session) garbageCollectStreams() { }) } -func (s *Session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { +func (s *session) sendPublicReset(rejectedPacketNumber protocol.PacketNumber) error { utils.Infof("Sending public reset for connection %x, packet number %d", s.connectionID, rejectedPacketNumber) return s.conn.write(writePublicReset(s.connectionID, rejectedPacketNumber, 0)) } // scheduleSending signals that we have data for sending -func (s *Session) scheduleSending() { +func (s *session) scheduleSending() { select { case s.sendingScheduled <- struct{}{}: default: } } -func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) { +func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.cryptoSetup.HandshakeComplete() { return } @@ -750,14 +752,14 @@ func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) { s.undecryptablePackets = append(s.undecryptablePackets, p) } -func (s *Session) tryDecryptingQueuedPackets() { +func (s *session) tryDecryptingQueuedPackets() { for _, p := range s.undecryptablePackets { s.handlePacket(p) } s.undecryptablePackets = s.undecryptablePackets[:0] } -func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { +func (s *session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { updates := s.flowControlManager.GetWindowUpdates() res := make([]*frames.WindowUpdateFrame, len(updates)) for i, u := range updates { @@ -766,12 +768,12 @@ func (s *Session) getWindowUpdateFrames() ([]*frames.WindowUpdateFrame, error) { return res, nil } -func (s *Session) ackAlarmChanged(t time.Time) { +func (s *session) ackAlarmChanged(t time.Time) { s.nextAckScheduledTime = t s.maybeResetTimer() } // RemoteAddr returns the net.Addr of the client -func (s *Session) RemoteAddr() net.Addr { +func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } diff --git a/session_test.go b/session_test.go index b6b7aa83..5590751d 100644 --- a/session_test.go +++ b/session_test.go @@ -114,8 +114,8 @@ var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{} var _ = Describe("Session", func() { var ( - session *Session - clientSession *Session + sess *session + clientSess *session streamCallbackCalled bool closeCallbackCalled bool scfg *handshake.ServerConfig @@ -133,78 +133,78 @@ var _ = Describe("Session", func() { Expect(err).NotTo(HaveOccurred()) scfg, err = handshake.NewServerConfig(kex, certChain) Expect(err).NotTo(HaveOccurred()) - pSession, err := newSession( + pSess, err := newSession( mconn, protocol.Version35, 0, scfg, - func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, ) Expect(err).NotTo(HaveOccurred()) - session = pSession.(*Session) - Expect(session.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream + sess = pSess.(*session) + Expect(sess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second} - session.connectionParameters = cpm + sess.connectionParameters = cpm - clientSession, err = newClientSession( + clientSess, err = newClientSession( &net.UDPConn{}, &net.UDPAddr{}, "hostname", protocol.Version35, 0, nil, - func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, func(isForwardSecure bool) {}, nil, ) Expect(err).ToNot(HaveOccurred()) - Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream + Expect(clientSess.streamsMap.openStreams).To(HaveLen(1)) // Crypto stream }) Context("source address", func() { It("uses the IP address if given an UDP connection", func() { conn := &conn{currentAddr: &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}} - session, err := newSession( + sess, err := newSession( conn, protocol.VersionWhatever, 0, scfg, - func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, ) Expect(err).ToNot(HaveOccurred()) - Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200})) + Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200})) }) It("uses the string representation of the remote addresses if not given a UDP connection", func() { conn := &conn{ currentAddr: &net.TCPAddr{IP: net.IPv4(192, 168, 100, 200)[12:], Port: 1337}, } - session, err := newSession( + sess, err := newSession( conn, protocol.VersionWhatever, 0, scfg, - func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, ) Expect(err).ToNot(HaveOccurred()) - Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(session.(*Session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337"))) + Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337"))) }) }) Context("when handling stream frames", func() { It("makes new streams", func() { - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) - str, err := session.streamsMap.GetOrOpenStream(5) + str, err := sess.streamsMap.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) _, err = str.Read(p) @@ -213,9 +213,9 @@ var _ = Describe("Session", func() { }) It("does not reject existing streams with even StreamIDs", func() { - _, err := session.GetOrOpenStream(5) + _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - err = session.handleStreamFrame(&frames.StreamFrame{ + err = sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) @@ -223,20 +223,20 @@ var _ = Describe("Session", func() { }) It("handles existing streams", func() { - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca}, }) - numOpenStreams := len(session.streamsMap.openStreams) + numOpenStreams := len(sess.streamsMap.openStreams) Expect(streamCallbackCalled).To(BeTrue()) - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Offset: 2, Data: []byte{0xfb, 0xad}, }) - Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams)) + Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams)) p := make([]byte, 4) - str, _ := session.streamsMap.GetOrOpenStream(5) + str, _ := sess.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) _, err := str.Read(p) Expect(err).ToNot(HaveOccurred()) @@ -244,128 +244,128 @@ var _ = Describe("Session", func() { }) It("does not delete streams with Close()", func() { - str, err := session.GetOrOpenStream(5) + str, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) str.Close() - session.garbageCollectStreams() - str, err = session.streamsMap.GetOrOpenStream(5) + sess.garbageCollectStreams() + str, err = sess.streamsMap.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) }) It("does not delete streams with FIN bit", func() { - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - numOpenStreams := len(session.streamsMap.openStreams) - str, _ := session.streamsMap.GetOrOpenStream(5) + numOpenStreams := len(sess.streamsMap.openStreams) + str, _ := sess.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err := str.Read(p) Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) - session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams)) - str, _ = session.streamsMap.GetOrOpenStream(5) + sess.garbageCollectStreams() + Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams)) + str, _ = sess.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) }) It("deletes streams with FIN bit & close", func() { - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, FinBit: true, }) - numOpenStreams := len(session.streamsMap.openStreams) - str, _ := session.streamsMap.GetOrOpenStream(5) + numOpenStreams := len(sess.streamsMap.openStreams) + str, _ := sess.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err := str.Read(p) Expect(err).To(MatchError(io.EOF)) Expect(p).To(Equal([]byte{0xde, 0xca, 0xfb, 0xad})) - session.garbageCollectStreams() - Expect(session.streamsMap.openStreams).To(HaveLen(numOpenStreams)) - str, _ = session.streamsMap.GetOrOpenStream(5) + sess.garbageCollectStreams() + Expect(sess.streamsMap.openStreams).To(HaveLen(numOpenStreams)) + str, _ = sess.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) // We still need to close the stream locally str.Close() // ... and simulate that we actually the FIN str.sentFin() - session.garbageCollectStreams() - Expect(len(session.streamsMap.openStreams)).To(BeNumerically("<", numOpenStreams)) - str, err = session.streamsMap.GetOrOpenStream(5) + sess.garbageCollectStreams() + Expect(len(sess.streamsMap.openStreams)).To(BeNumerically("<", numOpenStreams)) + str, err = sess.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(str).To(BeNil()) // flow controller should have been notified - _, err = session.flowControlManager.SendWindowSize(5) + _, err = sess.flowControlManager.SendWindowSize(5) Expect(err).To(MatchError("Error accessing the flowController map.")) }) It("cancels streams with error", func() { - session.garbageCollectStreams() + sess.garbageCollectStreams() testErr := errors.New("test") - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{0xde, 0xca, 0xfb, 0xad}, }) - str, err := session.streamsMap.GetOrOpenStream(5) + str, err := sess.streamsMap.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) Expect(streamCallbackCalled).To(BeTrue()) p := make([]byte, 4) _, err = str.Read(p) Expect(err).ToNot(HaveOccurred()) - session.closeStreamsWithError(testErr) + sess.closeStreamsWithError(testErr) _, err = str.Read(p) Expect(err).To(MatchError(testErr)) - session.garbageCollectStreams() - str, err = session.streamsMap.GetOrOpenStream(5) + sess.garbageCollectStreams() + str, err = sess.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(str).To(BeNil()) }) It("cancels empty streams with error", func() { testErr := errors.New("test") - session.GetOrOpenStream(5) - str, err := session.streamsMap.GetOrOpenStream(5) + sess.GetOrOpenStream(5) + str, err := sess.streamsMap.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) - session.closeStreamsWithError(testErr) + sess.closeStreamsWithError(testErr) _, err = str.Read([]byte{0}) Expect(err).To(MatchError(testErr)) - session.garbageCollectStreams() - str, err = session.streamsMap.GetOrOpenStream(5) + sess.garbageCollectStreams() + str, err = sess.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(str).To(BeNil()) }) It("informs the FlowControlManager about new streams", func() { // since the stream doesn't yet exist, this will throw an error - err := session.flowControlManager.UpdateHighestReceived(5, 1000) + err := sess.flowControlManager.UpdateHighestReceived(5, 1000) Expect(err).To(HaveOccurred()) - session.GetOrOpenStream(5) - err = session.flowControlManager.UpdateHighestReceived(5, 2000) + sess.GetOrOpenStream(5) + err = sess.flowControlManager.UpdateHighestReceived(5, 2000) Expect(err).ToNot(HaveOccurred()) }) It("ignores streams that existed previously", func() { - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{}, FinBit: true, }) - str, _ := session.streamsMap.GetOrOpenStream(5) + str, _ := sess.streamsMap.GetOrOpenStream(5) Expect(str).ToNot(BeNil()) _, err := str.Read([]byte{0}) Expect(err).To(MatchError(io.EOF)) str.Close() str.sentFin() - session.garbageCollectStreams() - err = session.handleStreamFrame(&frames.StreamFrame{ + sess.garbageCollectStreams() + err = sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte{}, }) @@ -375,9 +375,9 @@ var _ = Describe("Session", func() { Context("handling RST_STREAM frames", func() { It("closes the streams for writing", func() { - s, err := session.GetOrOpenStream(5) + s, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + err = sess.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, ErrorCode: 42, }) @@ -388,13 +388,13 @@ var _ = Describe("Session", func() { }) It("doesn't close the stream for reading", func() { - s, err := session.GetOrOpenStream(5) + s, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 5, Data: []byte("foobar"), }) - err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + err = sess.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, ErrorCode: 42, ByteOffset: 6, @@ -407,15 +407,15 @@ var _ = Describe("Session", func() { }) It("queues a RST_STERAM frame with the correct offset", func() { - str, err := session.GetOrOpenStream(5) + str, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) str.(*stream).writeOffset = 0x1337 - err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + err = sess.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.packer.controlFrames).To(HaveLen(1)) - Expect(session.packer.controlFrames[0].(*frames.RstStreamFrame)).To(Equal(&frames.RstStreamFrame{ + Expect(sess.packer.controlFrames).To(HaveLen(1)) + Expect(sess.packer.controlFrames[0].(*frames.RstStreamFrame)).To(Equal(&frames.RstStreamFrame{ StreamID: 5, ByteOffset: 0x1337, })) @@ -423,35 +423,35 @@ var _ = Describe("Session", func() { }) It("doesn't queue a RST_STREAM for a stream that it already sent a FIN on", func() { - str, err := session.GetOrOpenStream(5) + str, err := sess.GetOrOpenStream(5) str.(*stream).sentFin() str.Close() - err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + err = sess.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.packer.controlFrames).To(BeEmpty()) + Expect(sess.packer.controlFrames).To(BeEmpty()) Expect(str.(*stream).finished()).To(BeTrue()) }) It("passes the byte offset to the flow controller", func() { - session.streamsMap.GetOrOpenStream(5) - session.flowControlManager = newMockFlowControlHandler() - err := session.handleRstStreamFrame(&frames.RstStreamFrame{ + sess.streamsMap.GetOrOpenStream(5) + sess.flowControlManager = newMockFlowControlHandler() + err := sess.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, ByteOffset: 0x1337, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(protocol.StreamID(5))) - Expect(session.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(0x1337))) + Expect(sess.flowControlManager.(*mockFlowControlHandler).highestReceivedForStream).To(Equal(protocol.StreamID(5))) + Expect(sess.flowControlManager.(*mockFlowControlHandler).highestReceived).To(Equal(protocol.ByteCount(0x1337))) }) It("returns errors from the flow controller", func() { - session.streamsMap.GetOrOpenStream(5) - session.flowControlManager = newMockFlowControlHandler() + sess.streamsMap.GetOrOpenStream(5) + sess.flowControlManager = newMockFlowControlHandler() testErr := errors.New("flow control violation") - session.flowControlManager.(*mockFlowControlHandler).flowControlViolation = testErr - err := session.handleRstStreamFrame(&frames.RstStreamFrame{ + sess.flowControlManager.(*mockFlowControlHandler).flowControlViolation = testErr + err := sess.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, ByteOffset: 0x1337, }) @@ -459,7 +459,7 @@ var _ = Describe("Session", func() { }) It("ignores the error when the stream is not known", func() { - err := session.handleFrames([]frames.Frame{&frames.RstStreamFrame{ + err := sess.handleFrames([]frames.Frame{&frames.RstStreamFrame{ StreamID: 5, ErrorCode: 42, }}) @@ -468,12 +468,12 @@ var _ = Describe("Session", func() { It("queues a RST_STREAM when a stream gets reset locally", func() { testErr := errors.New("testErr") - str, err := session.streamsMap.GetOrOpenStream(5) + str, err := sess.streamsMap.GetOrOpenStream(5) str.writeOffset = 0x1337 Expect(err).ToNot(HaveOccurred()) str.Reset(testErr) - Expect(session.packer.controlFrames).To(HaveLen(1)) - Expect(session.packer.controlFrames[0]).To(Equal(&frames.RstStreamFrame{ + Expect(sess.packer.controlFrames).To(HaveLen(1)) + Expect(sess.packer.controlFrames[0]).To(Equal(&frames.RstStreamFrame{ StreamID: 5, ByteOffset: 0x1337, })) @@ -482,33 +482,33 @@ var _ = Describe("Session", func() { It("doesn't queue another RST_STREAM, when it receives an RST_STREAM as a response for the first", func() { testErr := errors.New("testErr") - str, err := session.streamsMap.GetOrOpenStream(5) + str, err := sess.streamsMap.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) str.Reset(testErr) - Expect(session.packer.controlFrames).To(HaveLen(1)) - err = session.handleRstStreamFrame(&frames.RstStreamFrame{ + Expect(sess.packer.controlFrames).To(HaveLen(1)) + err = sess.handleRstStreamFrame(&frames.RstStreamFrame{ StreamID: 5, ByteOffset: 0x42, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.packer.controlFrames).To(HaveLen(1)) + Expect(sess.packer.controlFrames).To(HaveLen(1)) }) }) Context("handling WINDOW_UPDATE frames", func() { It("updates the Flow Control Window of a stream", func() { - _, err := session.GetOrOpenStream(5) + _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - err = session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ + err = sess.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 100, }) Expect(err).ToNot(HaveOccurred()) - Expect(session.flowControlManager.SendWindowSize(5)).To(Equal(protocol.ByteCount(100))) + Expect(sess.flowControlManager.SendWindowSize(5)).To(Equal(protocol.ByteCount(100))) }) It("updates the Flow Control Window of the connection", func() { - err := session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ + err := sess.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ StreamID: 0, ByteOffset: 0x800000, }) @@ -516,22 +516,22 @@ var _ = Describe("Session", func() { }) It("opens a new stream when receiving a WINDOW_UPDATE for an unknown stream", func() { - err := session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ + err := sess.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 1337, }) Expect(err).ToNot(HaveOccurred()) - str, err := session.streamsMap.GetOrOpenStream(5) + str, err := sess.streamsMap.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) Expect(str).ToNot(BeNil()) }) It("errors when receiving a WindowUpdateFrame for a closed stream", func() { - session.handleStreamFrame(&frames.StreamFrame{StreamID: 5}) - err := session.streamsMap.RemoveStream(5) + sess.handleStreamFrame(&frames.StreamFrame{StreamID: 5}) + err := sess.streamsMap.RemoveStream(5) Expect(err).ToNot(HaveOccurred()) - session.garbageCollectStreams() - err = session.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ + sess.garbageCollectStreams() + err = sess.handleWindowUpdateFrame(&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 1337, }) @@ -539,11 +539,11 @@ var _ = Describe("Session", func() { }) It("ignores errors when receiving a WindowUpdateFrame for a closed stream", func() { - session.handleStreamFrame(&frames.StreamFrame{StreamID: 5}) - err := session.streamsMap.RemoveStream(5) + sess.handleStreamFrame(&frames.StreamFrame{StreamID: 5}) + err := sess.streamsMap.RemoveStream(5) Expect(err).ToNot(HaveOccurred()) - session.garbageCollectStreams() - err = session.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{ + sess.garbageCollectStreams() + err = sess.handleFrames([]frames.Frame{&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 1337, }}) @@ -552,28 +552,28 @@ var _ = Describe("Session", func() { }) It("handles PING frames", func() { - err := session.handleFrames([]frames.Frame{&frames.PingFrame{}}) + err := sess.handleFrames([]frames.Frame{&frames.PingFrame{}}) Expect(err).NotTo(HaveOccurred()) }) It("handles BLOCKED frames", func() { - err := session.handleFrames([]frames.Frame{&frames.BlockedFrame{}}) + err := sess.handleFrames([]frames.Frame{&frames.BlockedFrame{}}) Expect(err).NotTo(HaveOccurred()) }) It("errors on GOAWAY frames", func() { - err := session.handleFrames([]frames.Frame{&frames.GoawayFrame{}}) + err := sess.handleFrames([]frames.Frame{&frames.GoawayFrame{}}) Expect(err).To(MatchError("unimplemented: handling GOAWAY frames")) }) It("handles STOP_WAITING frames", func() { - err := session.handleFrames([]frames.Frame{&frames.StopWaitingFrame{LeastUnacked: 10}}) + err := sess.handleFrames([]frames.Frame{&frames.StopWaitingFrame{LeastUnacked: 10}}) Expect(err).NotTo(HaveOccurred()) }) It("handles CONNECTION_CLOSE frames", func() { - str, _ := session.GetOrOpenStream(5) - err := session.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}}) + str, _ := sess.GetOrOpenStream(5) + err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}}) Expect(err).NotTo(HaveOccurred()) _, err = str.Read([]byte{0}) Expect(err).To(MatchError(qerr.Error(42, "foobar"))) @@ -582,7 +582,7 @@ var _ = Describe("Session", func() { Context("accepting streams", func() { It("waits for new streams", func() { // stream 1 was already opened - str, err := session.AcceptStream() + str, err := sess.AcceptStream() Expect(err).ToNot(HaveOccurred()) Expect(str.StreamID()).To(Equal(protocol.StreamID(1))) str = nil @@ -590,11 +590,11 @@ var _ = Describe("Session", func() { go func() { defer GinkgoRecover() var err error - str, err = session.AcceptStream() + str, err = sess.AcceptStream() Expect(err).ToNot(HaveOccurred()) }() Consistently(func() utils.Stream { return str }).Should(BeNil()) - session.handleStreamFrame(&frames.StreamFrame{ + sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 3, }) Eventually(func() utils.Stream { return str }).ShouldNot(BeNil()) @@ -602,31 +602,31 @@ var _ = Describe("Session", func() { }) It("stops accepting when the session is closed", func() { - session.AcceptStream() // accept stream 1 + sess.AcceptStream() // accept stream 1 testErr := errors.New("testErr") var err error go func() { - _, err = session.AcceptStream() + _, err = sess.AcceptStream() }() - go session.run() + go sess.run() Consistently(func() error { return err }).ShouldNot(HaveOccurred()) - session.Close(testErr) + sess.Close(testErr) Eventually(func() error { return err }).Should(HaveOccurred()) Expect(err).To(MatchError(qerr.ToQuicError(testErr))) }) It("stops accepting when the session is closed after version negotiation", func() { - session.AcceptStream() // accept stream 1 + sess.AcceptStream() // accept stream 1 testErr := errCloseSessionForNewVersion var err error go func() { - _, err = session.AcceptStream() + _, err = sess.AcceptStream() }() - go session.run() + go sess.run() Consistently(func() error { return err }).ShouldNot(HaveOccurred()) - session.Close(testErr) + sess.Close(testErr) Eventually(func() error { return err }).Should(HaveOccurred()) Expect(err).To(MatchError(testErr)) }) @@ -640,32 +640,32 @@ var _ = Describe("Session", func() { BeforeEach(func() { time.Sleep(10 * time.Millisecond) // Wait for old goroutines to finish nGoRoutinesBefore = runtime.NumGoroutine() - go session.run() + go sess.run() Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore + 2)) }) It("shuts down without error", func() { - session.Close(nil) + sess.Close(nil) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0][len(mconn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})) Expect(closeCallbackCalled).To(BeTrue()) - Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() + Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() }) It("only closes once", func() { - session.Close(nil) - session.Close(nil) + sess.Close(nil) + sess.Close(nil) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(mconn.written).To(HaveLen(1)) - Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() + Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() }) It("closes streams with proper error", func() { testErr := errors.New("test error") - s, err := session.GetOrOpenStream(5) + s, err := sess.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) - session.Close(testErr) + sess.Close(testErr) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) Expect(closeCallbackCalled).To(BeTrue()) n, err := s.Read([]byte{0}) @@ -674,14 +674,14 @@ var _ = Describe("Session", func() { n, err = s.Write([]byte{0}) Expect(n).To(BeZero()) Expect(err.Error()).To(ContainSubstring(testErr.Error())) - Expect(session.runClosed).ToNot(Receive()) // channel should be drained by Close() + Expect(sess.runClosed).ToNot(Receive()) // channel should be drained by Close() }) It("closes the session in order to replace it with another QUIC version", func() { - session.Close(errCloseSessionForNewVersion) + sess.Close(errCloseSessionForNewVersion) Expect(closeCallbackCalled).To(BeFalse()) Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore)) - Expect(atomic.LoadUint32(&session.closed) != 0).To(BeTrue()) + Expect(atomic.LoadUint32(&sess.closed) != 0).To(BeTrue()) Expect(mconn.written).To(BeEmpty()) // no CONNECTION_CLOSE or PUBLIC_RESET sent }) }) @@ -690,98 +690,98 @@ var _ = Describe("Session", func() { var hdr *PublicHeader BeforeEach(func() { - session.unpacker = &mockUnpacker{} - clientSession.unpacker = &mockUnpacker{} + sess.unpacker = &mockUnpacker{} + clientSess.unpacker = &mockUnpacker{} hdr = &PublicHeader{PacketNumberLen: protocol.PacketNumberLen6} }) It("sets the {last,largest}RcvdPacketNumber", func() { hdr.PacketNumber = 5 - err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) - Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) - Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) + Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) + Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) }) It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() { hdr.PacketNumber = 5 - err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) - Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) - Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) + Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) + Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) hdr.PacketNumber = 3 - err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err = sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) - Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(3))) - Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) + Expect(sess.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(3))) + Expect(sess.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) }) It("ignores duplicate packets", func() { hdr.PacketNumber = 5 - err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) - err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err = sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) }) It("ignores packets smaller than the highest LeastUnacked of a StopWaiting", func() { - err := session.receivedPacketHandler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) + err := sess.receivedPacketHandler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) Expect(err).ToNot(HaveOccurred()) hdr.PacketNumber = 5 - err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err = sess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) }) It("passes the diversification nonce to the cryptoSetup, if it is a client", func() { hdr.PacketNumber = 5 hdr.DiversificationNonce = []byte("foobar") - err := clientSession.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + err := clientSess.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) - Expect((*[]byte)(unsafe.Pointer(reflect.ValueOf(clientSession.cryptoSetup).Elem().FieldByName("diversificationNonce").UnsafeAddr()))).To(Equal(&hdr.DiversificationNonce)) + Expect((*[]byte)(unsafe.Pointer(reflect.ValueOf(clientSess.cryptoSetup).Elem().FieldByName("diversificationNonce").UnsafeAddr()))).To(Equal(&hdr.DiversificationNonce)) }) Context("updating the remote address", func() { It("sets the remote address", func() { remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} - Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) + Expect(sess.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) p := receivedPacket{ remoteAddr: remoteIP, publicHeader: &PublicHeader{PacketNumber: 1337}, } - err := session.handlePacketImpl(&p) + err := sess.handlePacketImpl(&p) Expect(err).ToNot(HaveOccurred()) - Expect(session.conn.(*mockConnection).remoteAddr).To(Equal(remoteIP)) + Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(remoteIP)) }) It("doesn't change the remote address if authenticating the packet fails", func() { remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} attackerIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 102)} - session.conn.(*mockConnection).remoteAddr = remoteIP + sess.conn.(*mockConnection).remoteAddr = remoteIP // use the real packetUnpacker here, to make sure this test fails if the error code for failed decryption changes - session.unpacker = &packetUnpacker{} - session.unpacker.(*packetUnpacker).aead = &crypto.NullAEAD{} + sess.unpacker = &packetUnpacker{} + sess.unpacker.(*packetUnpacker).aead = &crypto.NullAEAD{} p := receivedPacket{ remoteAddr: attackerIP, publicHeader: &PublicHeader{PacketNumber: 1337}, } - err := session.handlePacketImpl(&p) + err := sess.handlePacketImpl(&p) quicErr := err.(*qerr.QuicError) Expect(quicErr.ErrorCode).To(Equal(qerr.DecryptionFailure)) - Expect(session.conn.(*mockConnection).remoteAddr).To(Equal(remoteIP)) + Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(remoteIP)) }) It("sets the remote address, if the packet is authenticated, but unpacking fails for another reason", func() { testErr := errors.New("testErr") remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} - Expect(session.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) + Expect(sess.conn.(*mockConnection).remoteAddr).ToNot(Equal(remoteIP)) p := receivedPacket{ remoteAddr: remoteIP, publicHeader: &PublicHeader{PacketNumber: 1337}, } - session.unpacker.(*mockUnpacker).unpackErr = testErr - err := session.handlePacketImpl(&p) + sess.unpacker.(*mockUnpacker).unpackErr = testErr + err := sess.handlePacketImpl(&p) Expect(err).To(MatchError(testErr)) - Expect(session.conn.(*mockConnection).remoteAddr).To(Equal(remoteIP)) + Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(remoteIP)) }) }) }) @@ -789,22 +789,22 @@ var _ = Describe("Session", func() { Context("sending packets", func() { It("sends ack frames", func() { packetNumber := protocol.PacketNumber(0x035E) - session.receivedPacketHandler.ReceivedPacket(packetNumber, true) - err := session.sendPacket() + sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) + err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x5E, 0x03}))) }) It("sends two WindowUpdate frames", func() { - _, err := session.GetOrOpenStream(5) + _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) - session.flowControlManager.AddBytesRead(5, protocol.ReceiveStreamFlowControlWindow) - err = session.sendPacket() + sess.flowControlManager.AddBytesRead(5, protocol.ReceiveStreamFlowControlWindow) + err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) - err = session.sendPacket() + err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) - err = session.sendPacket() + err = sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(2)) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x04, 0x05, 0, 0, 0}))) @@ -812,7 +812,7 @@ var _ = Describe("Session", func() { }) It("sends public reset", func() { - err := session.sendPublicReset(1) + err := sess.sendPublicReset(1) Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST")))) @@ -822,7 +822,7 @@ var _ = Describe("Session", func() { Context("retransmissions", func() { It("sends a StreamFrame from a packet queued for retransmission", func() { // a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet - session.packer.packetNumberGenerator.next = 0x1337 + 9 + sess.packer.packetNumberGenerator.next = 0x1337 + 9 f := frames.StreamFrame{ StreamID: 0x5, @@ -834,9 +834,9 @@ var _ = Describe("Session", func() { } sph := newMockSentPacketHandler() sph.(*mockSentPacketHandler).retransmissionQueue = []*ackhandler.Packet{&p} - session.sentPacketHandler = sph + sess.sentPacketHandler = sph - err := session.sendPacket() + err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) Expect(sph.(*mockSentPacketHandler).requestedStopWaiting).To(BeTrue()) @@ -845,7 +845,7 @@ var _ = Describe("Session", func() { It("sends a StreamFrame from a packet queued for retransmission", func() { // a StopWaitingFrame is added, so make sure the packet number of the new package is higher than the packet number of the retransmitted packet - session.packer.packetNumberGenerator.next = 0x1337 + 9 + sess.packer.packetNumberGenerator.next = 0x1337 + 9 f1 := frames.StreamFrame{ StreamID: 0x5, @@ -865,9 +865,9 @@ var _ = Describe("Session", func() { } sph := newMockSentPacketHandler() sph.(*mockSentPacketHandler).retransmissionQueue = []*ackhandler.Packet{&p1, &p2} - session.sentPacketHandler = sph + sess.sentPacketHandler = sph - err := session.sendPacket() + err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0]).To(ContainSubstring("foobar")) @@ -876,18 +876,18 @@ var _ = Describe("Session", func() { It("always attaches a StopWaiting to a packet that contains a retransmission", func() { // make sure the packet number of the new package is higher than the packet number of the retransmitted packet - session.packer.packetNumberGenerator.next = 0x1337 + 9 + sess.packer.packetNumberGenerator.next = 0x1337 + 9 f := &frames.StreamFrame{ StreamID: 0x5, Data: bytes.Repeat([]byte{'f'}, int(1.5*float32(protocol.MaxPacketSize))), } - session.streamFramer.AddFrameForRetransmission(f) + sess.streamFramer.AddFrameForRetransmission(f) sph := newMockSentPacketHandler() - session.sentPacketHandler = sph + sess.sentPacketHandler = sph - err := session.sendPacket() + err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(mconn.written).To(HaveLen(2)) sentPackets := sph.(*mockSentPacketHandler).sentPackets @@ -901,20 +901,20 @@ var _ = Describe("Session", func() { It("calls MaybeQueueRTOs even if congestion blocked, so that bytesInFlight is updated", func() { sph := newMockSentPacketHandler() sph.(*mockSentPacketHandler).congestionLimited = true - session.sentPacketHandler = sph - err := session.sendPacket() + sess.sentPacketHandler = sph + err := sess.sendPacket() Expect(err).NotTo(HaveOccurred()) Expect(sph.(*mockSentPacketHandler).maybeQueueRTOsCalled).To(BeTrue()) }) It("retransmits a WindowUpdates if it hasn't already sent a WindowUpdate with a higher ByteOffset", func() { - _, err := session.GetOrOpenStream(5) + _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) fc := newMockFlowControlHandler() fc.receiveWindow = 0x1000 - session.flowControlManager = fc + sess.flowControlManager = fc sph := newMockSentPacketHandler() - session.sentPacketHandler = sph + sess.sentPacketHandler = sph wuf := &frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 0x1000, @@ -922,7 +922,7 @@ var _ = Describe("Session", func() { sph.(*mockSentPacketHandler).retransmissionQueue = []*ackhandler.Packet{{ Frames: []frames.Frame{wuf}, }} - err = session.sendPacket() + err = sess.sendPacket() Expect(err).ToNot(HaveOccurred()) sentPackets := sph.(*mockSentPacketHandler).sentPackets Expect(sentPackets).To(HaveLen(1)) @@ -930,43 +930,43 @@ var _ = Describe("Session", func() { }) It("doesn't retransmit WindowUpdates if it already sent a WindowUpdate with a higher ByteOffset", func() { - _, err := session.GetOrOpenStream(5) + _, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) fc := newMockFlowControlHandler() fc.receiveWindow = 0x2000 - session.flowControlManager = fc + sess.flowControlManager = fc sph := newMockSentPacketHandler() - session.sentPacketHandler = sph + sess.sentPacketHandler = sph sph.(*mockSentPacketHandler).retransmissionQueue = []*ackhandler.Packet{{ Frames: []frames.Frame{&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 0x1000, }}, }} - err = session.sendPacket() + err = sess.sendPacket() Expect(err).ToNot(HaveOccurred()) Expect(sph.(*mockSentPacketHandler).sentPackets).To(BeEmpty()) }) It("doesn't retransmit WindowUpdates for closed streams", func() { - str, err := session.GetOrOpenStream(5) + str, err := sess.GetOrOpenStream(5) Expect(err).ToNot(HaveOccurred()) // close the stream str.(*stream).sentFin() str.Close() str.(*stream).RegisterRemoteError(nil) - session.garbageCollectStreams() - _, err = session.flowControlManager.SendWindowSize(5) + sess.garbageCollectStreams() + _, err = sess.flowControlManager.SendWindowSize(5) Expect(err).To(MatchError("Error accessing the flowController map.")) sph := newMockSentPacketHandler() - session.sentPacketHandler = sph + sess.sentPacketHandler = sph sph.(*mockSentPacketHandler).retransmissionQueue = []*ackhandler.Packet{{ Frames: []frames.Frame{&frames.WindowUpdateFrame{ StreamID: 5, ByteOffset: 0x1337, }}, }} - err = session.sendPacket() + err = sess.sendPacket() Expect(err).ToNot(HaveOccurred()) sentPackets := sph.(*mockSentPacketHandler).sentPackets Expect(sentPackets).To(BeEmpty()) @@ -975,23 +975,23 @@ var _ = Describe("Session", func() { Context("scheduling sending", func() { It("sends after writing to a stream", func(done Done) { - Expect(session.sendingScheduled).NotTo(Receive()) - s, err := session.GetOrOpenStream(3) + Expect(sess.sendingScheduled).NotTo(Receive()) + s, err := sess.GetOrOpenStream(3) Expect(err).NotTo(HaveOccurred()) go func() { s.Write([]byte("foobar")) close(done) }() - Eventually(session.sendingScheduled).Should(Receive()) + Eventually(sess.sendingScheduled).Should(Receive()) s.(*stream).getDataForWriting(1000) // unblock }) It("sets the timer to the ack timer", func() { rph := &mockReceivedPacketHandler{} rph.nextAckFrame = &frames.AckFrame{LargestAcked: 0x1337} - session.receivedPacketHandler = rph - go session.run() - session.ackAlarmChanged(time.Now().Add(10 * time.Millisecond)) + sess.receivedPacketHandler = rph + go sess.run() + sess.ackAlarmChanged(time.Now().Add(10 * time.Millisecond)) time.Sleep(10 * time.Millisecond) Eventually(func() int { return len(mconn.written) }).ShouldNot(BeZero()) Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x37, 0x13}))) @@ -999,17 +999,17 @@ var _ = Describe("Session", func() { Context("bundling of small packets", func() { It("bundles two small frames of different streams into one packet", func() { - s1, err := session.GetOrOpenStream(5) + s1, err := sess.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) - s2, err := session.GetOrOpenStream(7) + s2, err := sess.GetOrOpenStream(7) Expect(err).NotTo(HaveOccurred()) // Put data directly into the streams s1.(*stream).dataForWriting = []byte("foobar1") s2.(*stream).dataForWriting = []byte("foobar2") - session.scheduleSending() - go session.run() + sess.scheduleSending() + go sess.run() Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1)) Expect(mconn.written[0]).To(ContainSubstring("foobar1")) @@ -1017,11 +1017,11 @@ var _ = Describe("Session", func() { }) It("sends out two big frames in two packets", func() { - s1, err := session.GetOrOpenStream(5) + s1, err := sess.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) - s2, err := session.GetOrOpenStream(7) + s2, err := sess.GetOrOpenStream(7) Expect(err).NotTo(HaveOccurred()) - go session.run() + go sess.run() go func() { defer GinkgoRecover() _, err2 := s1.Write(bytes.Repeat([]byte{'e'}, 1000)) @@ -1033,9 +1033,9 @@ var _ = Describe("Session", func() { }) It("sends out two small frames that are written to long after one another into two packets", func() { - s, err := session.GetOrOpenStream(5) + s, err := sess.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) - go session.run() + go sess.run() _, err = s.Write([]byte("foobar1")) Expect(err).NotTo(HaveOccurred()) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1)) @@ -1046,11 +1046,11 @@ var _ = Describe("Session", func() { It("sends a queued ACK frame only once", func() { packetNumber := protocol.PacketNumber(0x1337) - session.receivedPacketHandler.ReceivedPacket(packetNumber, true) + sess.receivedPacketHandler.ReceivedPacket(packetNumber, true) - s, err := session.GetOrOpenStream(5) + s, err := sess.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) - go session.run() + go sess.run() _, err = s.Write([]byte("foobar1")) Expect(err).NotTo(HaveOccurred()) Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1)) @@ -1065,15 +1065,15 @@ var _ = Describe("Session", func() { }) It("closes when crypto stream errors", func() { - go session.run() - s, err := session.GetOrOpenStream(3) + go sess.run() + s, err := sess.GetOrOpenStream(3) Expect(err).NotTo(HaveOccurred()) - err = session.handleStreamFrame(&frames.StreamFrame{ + err = sess.handleStreamFrame(&frames.StreamFrame{ StreamID: 1, Data: []byte("4242\x00\x00\x00\x00"), }) Expect(err).NotTo(HaveOccurred()) - Eventually(func() bool { return atomic.LoadUint32(&session.closed) != 0 }).Should(BeTrue()) + Eventually(func() bool { return atomic.LoadUint32(&sess.closed) != 0 }).Should(BeTrue()) _, err = s.Write([]byte{}) Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.InvalidCryptoMessageType)) }) @@ -1084,37 +1084,37 @@ var _ = Describe("Session", func() { hdr := &PublicHeader{ PacketNumber: protocol.PacketNumber(i + 1), } - session.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")}) + sess.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")}) } - session.run() + sess.run() Expect(mconn.written).To(HaveLen(1)) Expect(mconn.written[0]).To(ContainSubstring(string([]byte("PRST")))) - Expect(session.runClosed).To(Receive()) + Expect(sess.runClosed).To(Receive()) }) It("ignores undecryptable packets after the handshake is complete", func() { - *(*bool)(unsafe.Pointer(reflect.ValueOf(session.cryptoSetup).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true + *(*bool)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true for i := 0; i < protocol.MaxUndecryptablePackets; i++ { hdr := &PublicHeader{ PacketNumber: protocol.PacketNumber(i + 1), } - session.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")}) + sess.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")}) } - go session.run() - Consistently(session.undecryptablePackets).Should(HaveLen(0)) - session.closeImpl(nil, true) - Eventually(session.runClosed).Should(Receive()) + go sess.run() + Consistently(sess.undecryptablePackets).Should(HaveLen(0)) + sess.closeImpl(nil, true) + Eventually(sess.runClosed).Should(Receive()) }) It("unqueues undecryptable packets for later decryption", func() { - session.undecryptablePackets = []*receivedPacket{{ + sess.undecryptablePackets = []*receivedPacket{{ publicHeader: &PublicHeader{PacketNumber: protocol.PacketNumber(42)}, }} - Expect(session.receivedPackets).NotTo(Receive()) - session.tryDecryptingQueuedPackets() - Expect(session.undecryptablePackets).To(BeEmpty()) - Expect(session.receivedPackets).To(Receive()) + Expect(sess.receivedPackets).NotTo(Receive()) + sess.tryDecryptingQueuedPackets() + Expect(sess.undecryptablePackets).To(BeEmpty()) + Expect(sess.receivedPackets).To(Receive()) }) It("calls the cryptoChangeCallback when the AEAD changes", func(done Done) { @@ -1124,10 +1124,10 @@ var _ = Describe("Session", func() { callbackCalled = true callbackCalledWith = p } - session.cryptoChangeCallback = cb - session.cryptoSetup = &mockCryptoSetup{handshakeComplete: false} - session.aeadChanged <- struct{}{} - go session.run() + sess.cryptoChangeCallback = cb + sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: false} + sess.aeadChanged <- struct{}{} + go sess.run() Eventually(func() bool { return callbackCalled }).Should(BeTrue()) Expect(callbackCalledWith).To(BeFalse()) close(done) @@ -1138,54 +1138,54 @@ var _ = Describe("Session", func() { cb := func(p bool) { callbackCalledWith = p } - session.cryptoChangeCallback = cb - session.cryptoSetup = &mockCryptoSetup{handshakeComplete: true} - session.aeadChanged <- struct{}{} - go session.run() + sess.cryptoChangeCallback = cb + sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: true} + sess.aeadChanged <- struct{}{} + go sess.run() Eventually(func() bool { return callbackCalledWith }).Should(BeTrue()) close(done) }) Context("timeouts", func() { It("times out due to no network activity", func(done Done) { - session.lastNetworkActivityTime = time.Now().Add(-time.Hour) - session.run() // Would normally not return + sess.lastNetworkActivityTime = time.Now().Add(-time.Hour) + sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(closeCallbackCalled).To(BeTrue()) - Expect(session.runClosed).To(Receive()) + Expect(sess.runClosed).To(Receive()) close(done) }) It("times out due to non-completed crypto handshake", func(done Done) { - session.sessionCreationTime = time.Now().Add(-time.Hour) - session.run() // Would normally not return + sess.sessionCreationTime = time.Now().Add(-time.Hour) + sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time.")) Expect(closeCallbackCalled).To(BeTrue()) - Expect(session.runClosed).To(Receive()) + Expect(sess.runClosed).To(Receive()) close(done) }) It("does not use ICSL before handshake", func(done Done) { - session.lastNetworkActivityTime = time.Now().Add(-time.Minute) + sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) cpm.idleTime = 99999 * time.Second - session.packer.connectionParameters = session.connectionParameters - session.run() // Would normally not return + sess.packer.connectionParameters = sess.connectionParameters + sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(closeCallbackCalled).To(BeTrue()) - Expect(session.runClosed).To(Receive()) + Expect(sess.runClosed).To(Receive()) close(done) }) It("uses ICSL after handshake", func(done Done) { - // session.lastNetworkActivityTime = time.Now().Add(-time.Minute) - *(*bool)(unsafe.Pointer(reflect.ValueOf(session.cryptoSetup).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true - *(*crypto.AEAD)(unsafe.Pointer(reflect.ValueOf(session.cryptoSetup).Elem().FieldByName("forwardSecureAEAD").UnsafeAddr())) = &crypto.NullAEAD{} + // sess.lastNetworkActivityTime = time.Now().Add(-time.Minute) + *(*bool)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("receivedForwardSecurePacket").UnsafeAddr())) = true + *(*crypto.AEAD)(unsafe.Pointer(reflect.ValueOf(sess.cryptoSetup).Elem().FieldByName("forwardSecureAEAD").UnsafeAddr())) = &crypto.NullAEAD{} cpm.idleTime = 0 * time.Millisecond - session.packer.connectionParameters = session.connectionParameters - session.run() // Would normally not return + sess.packer.connectionParameters = sess.connectionParameters + sess.run() // Would normally not return Expect(mconn.written[0]).To(ContainSubstring("No recent network activity.")) Expect(closeCallbackCalled).To(BeTrue()) - Expect(session.runClosed).To(Receive()) + Expect(sess.runClosed).To(Receive()) close(done) }) }) @@ -1194,18 +1194,18 @@ var _ = Describe("Session", func() { streamFrame := frames.StreamFrame{StreamID: 5, Data: []byte("foobar")} for i := protocol.PacketNumber(1); i < protocol.MaxTrackedSentPackets+10; i++ { packet := ackhandler.Packet{PacketNumber: protocol.PacketNumber(i), Frames: []frames.Frame{&streamFrame}, Length: 1} - err := session.sentPacketHandler.SentPacket(&packet) + err := sess.sentPacketHandler.SentPacket(&packet) Expect(err).ToNot(HaveOccurred()) } - // now session.sentPacketHandler.CheckForError will return an error - err := session.sendPacket() + // now sess.sentPacketHandler.CheckForError will return an error + err := sess.sendPacket() Expect(err).To(MatchError(ackhandler.ErrTooManyTrackedSentPackets)) }) It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) { // Nothing here should block for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+10; i++ { - session.handlePacket(&receivedPacket{}) + sess.handlePacket(&receivedPacket{}) } close(done) }, 0.5) @@ -1214,17 +1214,17 @@ var _ = Describe("Session", func() { // We simulate consistently low RTTs, so that the test works faster n := protocol.PacketNumber(10) for p := protocol.PacketNumber(1); p < n; p++ { - err := session.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: p, Length: 1}) + err := sess.sentPacketHandler.SentPacket(&ackhandler.Packet{PacketNumber: p, Length: 1}) Expect(err).NotTo(HaveOccurred()) time.Sleep(time.Microsecond) ack := &frames.AckFrame{} ack.LargestAcked = p - err = session.sentPacketHandler.ReceivedAck(ack, p, time.Now()) + err = sess.sentPacketHandler.ReceivedAck(ack, p, time.Now()) Expect(err).NotTo(HaveOccurred()) } - session.packer.packetNumberGenerator.next = n + 1 + sess.packer.packetNumberGenerator.next = n + 1 // Now, we send a single packet, and expect that it was retransmitted later - err := session.sentPacketHandler.SentPacket(&ackhandler.Packet{ + err := sess.sentPacketHandler.SentPacket(&ackhandler.Packet{ PacketNumber: n, Length: 1, Frames: []frames.Frame{&frames.StreamFrame{ @@ -1232,27 +1232,27 @@ var _ = Describe("Session", func() { }}, }) Expect(err).NotTo(HaveOccurred()) - go session.run() - session.scheduleSending() + go sess.run() + sess.scheduleSending() Eventually(func() [][]byte { return mconn.written }).ShouldNot(BeEmpty()) Expect(mconn.written[0]).To(ContainSubstring("foobar")) }) Context("getting streams", func() { It("returns a new stream", func() { - str, err := session.GetOrOpenStream(11) + str, err := sess.GetOrOpenStream(11) Expect(err).ToNot(HaveOccurred()) Expect(str).ToNot(BeNil()) Expect(str.StreamID()).To(Equal(protocol.StreamID(11))) }) It("returns a nil-value (not an interface with value nil) for closed streams", func() { - _, err := session.GetOrOpenStream(9) + _, err := sess.GetOrOpenStream(9) Expect(err).ToNot(HaveOccurred()) - session.streamsMap.RemoveStream(9) - session.garbageCollectStreams() - Expect(session.streamsMap.GetOrOpenStream(9)).To(BeNil()) - str, err := session.GetOrOpenStream(9) + sess.streamsMap.RemoveStream(9) + sess.garbageCollectStreams() + Expect(sess.streamsMap.GetOrOpenStream(9)).To(BeNil()) + str, err := sess.GetOrOpenStream(9) Expect(err).ToNot(HaveOccurred()) Expect(str).To(BeNil()) // make sure that the returned value is a plain nil, not an utils.Stream with value nil @@ -1264,16 +1264,16 @@ var _ = Describe("Session", func() { Context("counting streams", func() { It("errors when too many streams are opened", func() { for i := 0; i < 110; i++ { - _, err := session.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + _, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) } - _, err := session.GetOrOpenStream(protocol.StreamID(301)) + _, err := sess.GetOrOpenStream(protocol.StreamID(301)) Expect(err).To(MatchError(qerr.TooManyOpenStreams)) }) It("does not error when many streams are opened and closed", func() { for i := 2; i <= 1000; i++ { - s, err := session.GetOrOpenStream(protocol.StreamID(i*2 + 1)) + s, err := sess.GetOrOpenStream(protocol.StreamID(i*2 + 1)) Expect(err).NotTo(HaveOccurred()) err = s.Close() Expect(err).NotTo(HaveOccurred()) @@ -1281,22 +1281,22 @@ var _ = Describe("Session", func() { s.CloseRemote(0) _, err = s.Read([]byte("a")) Expect(err).To(MatchError(io.EOF)) - session.garbageCollectStreams() + sess.garbageCollectStreams() } }) }) Context("ignoring errors", func() { It("ignores duplicate acks", func() { - session.sentPacketHandler.SentPacket(&ackhandler.Packet{ + sess.sentPacketHandler.SentPacket(&ackhandler.Packet{ PacketNumber: 1, Length: 1, }) - err := session.handleFrames([]frames.Frame{&frames.AckFrame{ + err := sess.handleFrames([]frames.Frame{&frames.AckFrame{ LargestAcked: 1, }}) Expect(err).NotTo(HaveOccurred()) - err = session.handleFrames([]frames.Frame{&frames.AckFrame{ + err = sess.handleFrames([]frames.Frame{&frames.AckFrame{ LargestAcked: 1, }}) Expect(err).NotTo(HaveOccurred()) @@ -1305,9 +1305,9 @@ var _ = Describe("Session", func() { Context("window updates", func() { It("gets stream level window updates", func() { - err := session.flowControlManager.AddBytesRead(1, protocol.ReceiveStreamFlowControlWindow) + err := sess.flowControlManager.AddBytesRead(1, protocol.ReceiveStreamFlowControlWindow) Expect(err).NotTo(HaveOccurred()) - frames, err := session.getWindowUpdateFrames() + frames, err := sess.getWindowUpdateFrames() Expect(err).NotTo(HaveOccurred()) Expect(frames).To(HaveLen(1)) Expect(frames[0].StreamID).To(Equal(protocol.StreamID(1))) @@ -1315,11 +1315,11 @@ var _ = Describe("Session", func() { }) It("gets connection level window updates", func() { - _, err := session.GetOrOpenStream(5) + _, err := sess.GetOrOpenStream(5) Expect(err).NotTo(HaveOccurred()) - err = session.flowControlManager.AddBytesRead(5, protocol.ReceiveConnectionFlowControlWindow) + err = sess.flowControlManager.AddBytesRead(5, protocol.ReceiveConnectionFlowControlWindow) Expect(err).NotTo(HaveOccurred()) - frames, err := session.getWindowUpdateFrames() + frames, err := sess.getWindowUpdateFrames() Expect(err).NotTo(HaveOccurred()) Expect(frames).To(HaveLen(1)) Expect(frames[0].StreamID).To(Equal(protocol.StreamID(0)))