diff --git a/.travis.yml b/.travis.yml index 23faf62c3..94db43c8b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,10 +14,13 @@ go: # setting the GOARCH directly doesn't work, since the value will be overwritten later # so set it to a temporary environment variable first env: - - TRAVIS_GOARCH=amd64 TESTMODE=unit - - TRAVIS_GOARCH=amd64 TESTMODE=integration - - TRAVIS_GOARCH=386 TESTMODE=unit - - TRAVIS_GOARCH=386 TESTMODE=integration + global: + - TIMESCALE_FACTOR=20 + matrix: + - TRAVIS_GOARCH=amd64 TESTMODE=unit + - TRAVIS_GOARCH=amd64 TESTMODE=integration + - TRAVIS_GOARCH=386 TESTMODE=unit + - TRAVIS_GOARCH=386 TESTMODE=integration # second part of the GOARCH workaround # now actually set the GOARCH env variable to the value of the temporary variable set earlier diff --git a/Changelog.md b/Changelog.md index 591752b75..0b357ab8e 100644 --- a/Changelog.md +++ b/Changelog.md @@ -8,6 +8,7 @@ - Add a `quic.Config` option to configure the source address validation - Add a `quic.Config` option to configure the handshake timeout - Add a `quic.Config` option to configure keep-alive +- Implement `net.Conn`-style deadlines for streams - Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details. - Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details. - Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper` diff --git a/appveyor.yml b/appveyor.yml index 1de758e0f..a013bb3bc 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -5,6 +5,7 @@ os: Windows Server 2012 R2 environment: GOPATH: c:\gopath CGO_ENABLED: 0 + TIMESCALE_FACTOR: 20 matrix: - GOARCH: 386 - GOARCH: amd64 diff --git a/h2quic/response_writer_test.go b/h2quic/response_writer_test.go index f8217f316..93ebae162 100644 --- a/h2quic/response_writer_test.go +++ b/h2quic/response_writer_test.go @@ -5,6 +5,7 @@ import ( "io" "net/http" "sync" + "time" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" @@ -36,6 +37,9 @@ func (s *mockStream) Close() error { s.closed = true; r func (s *mockStream) Reset(error) { s.reset = true } func (s *mockStream) CloseRemote(offset protocol.ByteCount) { s.remoteClosed = true } func (s mockStream) StreamID() protocol.StreamID { return s.id } +func (s *mockStream) SetDeadline(time.Time) error { panic("not implemented") } +func (s *mockStream) SetReadDeadline(time.Time) error { panic("not implemented") } +func (s *mockStream) SetWriteDeadline(time.Time) error { panic("not implemented") } func (s *mockStream) Read(p []byte) (int, error) { n, _ := s.dataToRead.Read(p) diff --git a/interface.go b/interface.go index 4ba4cb42b..685a1ab79 100644 --- a/interface.go +++ b/interface.go @@ -16,6 +16,20 @@ type Stream interface { StreamID() protocol.StreamID // Reset closes the stream with an error. Reset(error) + // SetReadDeadline sets the deadline for future Read calls and + // any currently-blocked Read call. + // A zero value for t means Read will not time out. + SetReadDeadline(t time.Time) error + // SetWriteDeadline sets the deadline for future Write calls + // and any currently-blocked Write call. + // Even if write times out, it may return n > 0, indicating that + // some of the data was successfully written. + // A zero value for t means Write will not time out. + SetWriteDeadline(t time.Time) error + // SetDeadline sets the read and write deadlines associated + // with the connection. It is equivalent to calling both + // SetReadDeadline and SetWriteDeadline. + SetDeadline(t time.Time) error } // A Session is a QUIC connection between two peers. diff --git a/stream.go b/stream.go index 183db9cbc..f78f815db 100644 --- a/stream.go +++ b/stream.go @@ -1,9 +1,11 @@ package quic import ( + "errors" "fmt" "io" "sync" + "time" "github.com/lucas-clemente/quic-go/flowcontrol" "github.com/lucas-clemente/quic-go/frames" @@ -40,32 +42,35 @@ type stream struct { // resetRemotely is set if RegisterRemoteError() is called resetRemotely utils.AtomicBool - frameQueue *streamFrameSorter - newFrameOrErrCond sync.Cond + frameQueue *streamFrameSorter + readChan chan struct{} + readDeadline time.Time - dataForWriting []byte - finSent utils.AtomicBool - rstSent utils.AtomicBool - doneWritingOrErrCond sync.Cond + dataForWriting []byte + finSent utils.AtomicBool + rstSent utils.AtomicBool + writeChan chan struct{} + writeDeadline time.Time flowControlManager flowcontrol.FlowControlManager } +var errDeadline = errors.New("deadline exceeded") + // newStream creates a new Stream func newStream(StreamID protocol.StreamID, onData func(), onReset func(protocol.StreamID, protocol.ByteCount), flowControlManager flowcontrol.FlowControlManager) *stream { - s := &stream{ + return &stream{ onData: onData, onReset: onReset, streamID: StreamID, flowControlManager: flowControlManager, frameQueue: newStreamFrameSorter(), + readChan: make(chan struct{}, 1), + writeChan: make(chan struct{}, 1), } - s.newFrameOrErrCond.L = &s.mutex - s.doneWritingOrErrCond.L = &s.mutex - return s } // Read implements io.Reader. It is not thread safe! @@ -84,10 +89,10 @@ func (s *stream) Read(p []byte) (int, error) { for bytesRead < len(p) { s.mutex.Lock() frame := s.frameQueue.Head() - if frame == nil && bytesRead > 0 { + err = s.err s.mutex.Unlock() - return bytesRead, s.err + return bytesRead, err } var err error @@ -97,11 +102,28 @@ func (s *stream) Read(p []byte) (int, error) { err = s.err break } + + deadline := s.readDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + err = errDeadline + break + } + if frame != nil { s.readPosInFrame = int(s.readOffset - frame.Offset) break } - s.newFrameOrErrCond.Wait() + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.readChan + } else { + select { + case <-s.readChan: + case <-time.After(deadline.Sub(time.Now())): + } + } + s.mutex.Lock() frame = s.frameQueue.Head() } s.mutex.Unlock() @@ -146,34 +168,49 @@ func (s *stream) Read(p []byte) (int, error) { } func (s *stream) Write(p []byte) (int, error) { - if s.resetLocally.Get() { - return 0, s.err - } - s.mutex.Lock() defer s.mutex.Unlock() - if s.err != nil { + if s.resetLocally.Get() || s.err != nil { return 0, s.err } - if len(p) == 0 { return 0, nil } s.dataForWriting = make([]byte, len(p)) copy(s.dataForWriting, p) - s.onData() - for s.dataForWriting != nil && s.err == nil { - s.doneWritingOrErrCond.Wait() + var err error + for { + deadline := s.writeDeadline + if !deadline.IsZero() && !time.Now().Before(deadline) { + err = errDeadline + break + } + if s.dataForWriting == nil || s.err != nil { + break + } + + s.mutex.Unlock() + if deadline.IsZero() { + <-s.writeChan + } else { + select { + case <-s.writeChan: + case <-time.After(deadline.Sub(time.Now())): + } + } + s.mutex.Lock() } + if err != nil { + return 0, err + } if s.err != nil { return len(p) - len(s.dataForWriting), s.err } - return len(p), nil } @@ -189,14 +226,12 @@ func (s *stream) lenOfDataForWriting() protocol.ByteCount { func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { s.mutex.Lock() - if s.err != nil { - s.mutex.Unlock() - return nil - } - if s.dataForWriting == nil { - s.mutex.Unlock() + defer s.mutex.Unlock() + + if s.err != nil || s.dataForWriting == nil { return nil } + var ret []byte if protocol.ByteCount(len(s.dataForWriting)) > maxBytes { ret = s.dataForWriting[:maxBytes] @@ -204,10 +239,9 @@ func (s *stream) getDataForWriting(maxBytes protocol.ByteCount) []byte { } else { ret = s.dataForWriting s.dataForWriting = nil - s.doneWritingOrErrCond.Signal() + s.signalWrite() } s.writeOffset += protocol.ByteCount(len(ret)) - s.mutex.Unlock() return ret } @@ -250,7 +284,52 @@ func (s *stream) AddStreamFrame(frame *frames.StreamFrame) error { if err != nil && err != errDuplicateStreamData { return err } - s.newFrameOrErrCond.Signal() + s.signalRead() + return nil +} + +// signalRead performs a non-blocking send on the readChan +func (s *stream) signalRead() { + select { + case s.readChan <- struct{}{}: + default: + } +} + +// signalRead performs a non-blocking send on the writeChan +func (s *stream) signalWrite() { + select { + case s.writeChan <- struct{}{}: + default: + } +} + +func (s *stream) SetReadDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.readDeadline + s.readDeadline = t + s.mutex.Unlock() + // if the new deadline is before the currently set deadline, wake up Read() + if t.Before(oldDeadline) { + s.signalRead() + } + return nil +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + s.mutex.Lock() + oldDeadline := s.writeDeadline + s.writeDeadline = t + s.mutex.Unlock() + if t.Before(oldDeadline) { + s.signalWrite() + } + return nil +} + +func (s *stream) SetDeadline(t time.Time) error { + _ = s.SetReadDeadline(t) // SetReadDeadline never errors + _ = s.SetWriteDeadline(t) // SetWriteDeadline never errors return nil } @@ -267,8 +346,8 @@ func (s *stream) Cancel(err error) { // errors must not be changed! if s.err == nil { s.err = err - s.newFrameOrErrCond.Signal() - s.doneWritingOrErrCond.Signal() + s.signalRead() + s.signalWrite() } s.mutex.Unlock() } @@ -283,8 +362,8 @@ func (s *stream) Reset(err error) { // errors must not be changed! if s.err == nil { s.err = err - s.newFrameOrErrCond.Signal() - s.doneWritingOrErrCond.Signal() + s.signalRead() + s.signalWrite() } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset) @@ -303,7 +382,7 @@ func (s *stream) RegisterRemoteError(err error) { // errors must not be changed! if s.err == nil { s.err = err - s.doneWritingOrErrCond.Signal() + s.signalWrite() } if s.shouldSendReset() { s.onReset(s.streamID, s.writeOffset) diff --git a/stream_test.go b/stream_test.go index ad066f9ca..0d1bd6e10 100644 --- a/stream_test.go +++ b/stream_test.go @@ -3,8 +3,12 @@ package quic import ( "errors" "io" + "runtime" + "strconv" "time" + "os" + "github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/internal/mocks/mocks_fc" "github.com/lucas-clemente/quic-go/protocol" @@ -26,6 +30,18 @@ var _ = Describe("Stream", func() { mockFcm *mocks_fc.MockFlowControlManager ) + // in the tests for the stream deadlines we set a deadline + // and wait to make an assertion when Read / Write was unblocked + // on the CIs, the timing is a lot less precise, so scale every duration by this factor + scaleDuration := func(t time.Duration) time.Duration { + scaleFactor := 1 + if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set + scaleFactor = f + } + Expect(scaleFactor).ToNot(BeZero()) + return time.Duration(scaleFactor) * t + } + onData := func() { onDataCalled = true } @@ -135,11 +151,9 @@ var _ = Describe("Stream", func() { mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(2)) mockFcm.EXPECT().AddBytesRead(streamID, protocol.ByteCount(2)) go func() { - frame := frames.StreamFrame{ - Offset: 0, - Data: []byte{0xDE, 0xAD}, - } - time.Sleep(time.Millisecond) + defer GinkgoRecover() + frame := frames.StreamFrame{Data: []byte{0xDE, 0xAD}} + time.Sleep(10 * time.Millisecond) err := str.AddStreamFrame(&frame) Expect(err).ToNot(HaveOccurred()) }() @@ -240,6 +254,79 @@ var _ = Describe("Stream", func() { Expect(onDataCalled).To(BeTrue()) }) + Context("deadlines", func() { + It("returns an error when Read is called after the deadline", func() { + mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(6)).AnyTimes() + f := &frames.StreamFrame{Data: []byte("foobar")} + err := str.AddStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + str.SetReadDeadline(time.Now().Add(-time.Second)) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("unblocks after the deadline", func() { + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetReadDeadline(deadline) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(10*time.Millisecond))) + }) + + It("doesn't unblock if the deadline is changed before the first one expires", func() { + deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) + str.SetReadDeadline(deadline1) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetReadDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline1)) + }() + runtime.Gosched() + b := make([]byte, 10) + n, err := str.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + }) + + It("unblocks earlier, when a new deadline is set", func() { + deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(10 * time.Millisecond)) + str.SetReadDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline2)) + }() + str.SetReadDeadline(deadline1) + runtime.Gosched() + b := make([]byte, 10) + _, err := str.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(25*time.Millisecond))) + }) + + It("sets a read deadline, when SetDeadline is called", func() { + mockFcm.EXPECT().UpdateHighestReceived(streamID, protocol.ByteCount(6)).AnyTimes() + f := &frames.StreamFrame{Data: []byte("foobar")} + err := str.AddStreamFrame(f) + Expect(err).ToNot(HaveOccurred()) + str.SetDeadline(time.Now().Add(-time.Second)) + b := make([]byte, 6) + n, err := str.Read(b) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + }) + Context("closing", func() { Context("with FIN bit", func() { It("returns EOFs", func() { @@ -703,6 +790,66 @@ var _ = Describe("Stream", func() { Expect(err).ToNot(HaveOccurred()) }) + Context("deadlines", func() { + It("returns an error when Write is called after the deadline", func() { + str.SetWriteDeadline(time.Now().Add(-time.Second)) + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + + It("unblocks after the deadline", func() { + deadline := time.Now().Add(scaleDuration(50 * time.Millisecond)) + str.SetWriteDeadline(deadline) + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline, scaleDuration(20*time.Millisecond))) + }) + + It("doesn't unblock if the deadline is changed before the first one expires", func() { + deadline1 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(100 * time.Millisecond)) + str.SetWriteDeadline(deadline1) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(20 * time.Millisecond)) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline1)) + }() + runtime.Gosched() + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + }) + + It("unblocks earlier, when a new deadline is set", func() { + deadline1 := time.Now().Add(scaleDuration(200 * time.Millisecond)) + deadline2 := time.Now().Add(scaleDuration(50 * time.Millisecond)) + go func() { + defer GinkgoRecover() + time.Sleep(scaleDuration(10 * time.Millisecond)) + str.SetWriteDeadline(deadline2) + // make sure that this was actually execute before the deadline expires + Expect(time.Now()).To(BeTemporally("<", deadline2)) + }() + str.SetWriteDeadline(deadline1) + runtime.Gosched() + _, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(time.Now()).To(BeTemporally("~", deadline2, scaleDuration(20*time.Millisecond))) + }) + + It("sets a read deadline, when SetDeadline is called", func() { + str.SetDeadline(time.Now().Add(-time.Second)) + n, err := str.Write([]byte("foobar")) + Expect(err).To(MatchError(errDeadline)) + Expect(n).To(BeZero()) + }) + }) + Context("closing", func() { It("sets finishedWriting when calling Close", func() { str.Close()