replace Session.WaitUntilClosed() by a context

This commit is contained in:
Marten Seemann
2017-07-26 16:44:17 +07:00
parent 811315e31a
commit e02f5d5fbe
5 changed files with 41 additions and 43 deletions

View File

@@ -2,6 +2,7 @@ package h2quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
@@ -61,7 +62,7 @@ func (s *mockSession) LocalAddr() net.Addr {
func (s *mockSession) RemoteAddr() net.Addr {
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
}
func (s *mockSession) WaitUntilClosed() { panic("not implemented") }
func (s *mockSession) Context() context.Context { panic("not implemented") }
var _ = Describe("H2 server", func() {
var (

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"io"
"net"
"time"
@@ -56,9 +57,9 @@ type Session interface {
RemoteAddr() net.Addr
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
Close(error) error
// WaitUntilClosed() blocks until the session is closed.
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
WaitUntilClosed()
Context() context.Context
}
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.

View File

@@ -2,6 +2,7 @@ package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"net"
@@ -40,9 +41,6 @@ func (s *mockSession) run() error {
func (s *mockSession) WaitUntilHandshakeComplete() error {
return <-s.handshakeComplete
}
func (*mockSession) WaitUntilClosed() {
panic("not implemented")
}
func (s *mockSession) Close(e error) error {
if s.closed {
return nil
@@ -58,21 +56,14 @@ func (s *mockSession) closeRemote(e error) {
s.closedRemote = true
close(s.stopRunLoop)
}
func (s *mockSession) AcceptStream() (Stream, error) {
panic("not implemented")
}
func (s *mockSession) OpenStream() (Stream, error) {
return &stream{streamID: 1337}, nil
}
func (s *mockSession) OpenStreamSync() (Stream, error) {
panic("not implemented")
}
func (s *mockSession) LocalAddr() net.Addr {
panic("not implemented")
}
func (s *mockSession) RemoteAddr() net.Addr {
panic("not implemented")
}
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
func (*mockSession) Context() context.Context { panic("not implemented") }
var _ Session = &mockSession{}
var _ NonFWSession = &mockSession{}

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"crypto/tls"
"errors"
"fmt"
@@ -78,11 +79,11 @@ type session struct {
sendingScheduled chan struct{}
// closeChan is used to notify the run loop that it should terminate.
closeChan chan closeError
// runClosed is closed once the run loop exits
// it is used to block Close() and WaitUntilClosed()
runClosed chan struct{}
closeOnce sync.Once
ctx context.Context
ctxCancel context.CancelFunc
// when we receive too many undecryptable packets during the handshake, we send a Public reset
// but only after a time of protocol.PublicResetTimeout has passed
undecryptablePackets []*receivedPacket
@@ -167,12 +168,12 @@ func (s *session) setup(
s.aeadChanged = aeadChanged
handshakeChan := make(chan handshakeEvent, 3)
s.handshakeChan = handshakeChan
s.runClosed = make(chan struct{})
s.handshakeCompleteChan = make(chan error, 1)
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.timer = utils.NewTimer()
now := time.Now()
@@ -333,12 +334,12 @@ runLoop:
s.handshakeChan <- handshakeEvent{err: closeErr.err}
}
s.handleCloseError(closeErr)
close(s.runClosed)
defer s.ctxCancel()
return closeErr.err
}
func (s *session) WaitUntilClosed() {
<-s.runClosed
func (s *session) Context() context.Context {
return s.ctx
}
func (s *session) maybeResetTimer() {
@@ -543,7 +544,7 @@ func (s *session) closeRemote(e error) {
// It waits until the run loop has stopped before returning
func (s *session) Close(e error) error {
s.closeLocal(e)
<-s.runClosed
<-s.ctx.Done()
return nil
}

View File

@@ -2,6 +2,7 @@ package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"io"
@@ -648,7 +649,7 @@ var _ = Describe("Session", func() {
str, _ := sess.GetOrOpenStream(5)
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
Expect(err).NotTo(HaveOccurred())
Eventually(sess.runClosed).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
_, err = str.Read([]byte{0})
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
close(done)
@@ -744,11 +745,11 @@ var _ = Describe("Session", func() {
}()
go sess.run()
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
Expect(sess.runClosed).ToNot(BeClosed())
Expect(sess.Context().Done()).ToNot(BeClosed())
sess.Close(errCloseSessionForNewVersion)
Eventually(func() error { return err }).Should(HaveOccurred())
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, errCloseSessionForNewVersion.Error())))
Eventually(sess.runClosed).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
})
@@ -764,7 +765,7 @@ var _ = Describe("Session", func() {
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
})
It("only closes once", func() {
@@ -772,7 +773,7 @@ var _ = Describe("Session", func() {
sess.Close(nil)
Eventually(areSessionsRunning).Should(BeFalse())
Expect(mconn.written).To(HaveLen(1))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
})
It("closes streams with proper error", func() {
@@ -787,7 +788,7 @@ var _ = Describe("Session", func() {
n, err = s.Write([]byte{0})
Expect(n).To(BeZero())
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
})
It("closes the session in order to replace it with another QUIC version", func() {
@@ -800,13 +801,16 @@ var _ = Describe("Session", func() {
sess.Close(handshake.ErrHOLExperiment)
Expect(mconn.written).To(HaveLen(1))
Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
})
It("unblocks WaitUntilClosed when the run loop exists", func() {
It("cancels the context when the run loop exists", func() {
returned := make(chan struct{})
go func() {
sess.WaitUntilClosed()
defer GinkgoRecover()
ctx := sess.Context()
<-ctx.Done()
Expect(ctx.Err()).To(MatchError(context.Canceled))
close(returned)
}()
Consistently(returned).ShouldNot(BeClosed())
@@ -841,7 +845,7 @@ var _ = Describe("Session", func() {
sess.unpacker.(*mockUnpacker).unpackErr = testErr
sess.handlePacket(&receivedPacket{publicHeader: hdr})
Eventually(func() error { return runErr }).Should(MatchError(testErr))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
close(done)
})
@@ -1362,7 +1366,7 @@ var _ = Describe("Session", func() {
sess.scheduleSending() // wake up the run loop
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
Expect(mconn.written[0]).To(ContainSubstring("PRST"))
Eventually(sess.runClosed).Should(BeClosed())
Eventually(sess.Context().Done()).Should(BeClosed())
})
It("doesn't send a Public Reset if decrypting them suceeded during the timeout", func() {
@@ -1372,7 +1376,7 @@ var _ = Describe("Session", func() {
// there are no packets in the undecryptable packet queue
// in reality, this happens when the trial decryption succeeded during the Public Reset timeout
Consistently(func() [][]byte { return mconn.written }).ShouldNot(HaveLen(1))
Expect(sess.runClosed).ToNot(Receive())
Expect(sess.Context().Done()).ToNot(Receive())
sess.Close(nil)
})
@@ -1477,7 +1481,7 @@ var _ = Describe("Session", func() {
err := sess.run() // Would normally not return
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
close(done)
})
@@ -1486,7 +1490,7 @@ var _ = Describe("Session", func() {
err := sess.run() // Would normally not return
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
close(done)
})
@@ -1500,7 +1504,7 @@ var _ = Describe("Session", func() {
err := sess.run() // Would normally not return
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
close(done)
})
@@ -1515,7 +1519,7 @@ var _ = Describe("Session", func() {
err := sess.run() // Would normally not return
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
Expect(sess.runClosed).To(BeClosed())
Expect(sess.Context().Done()).To(BeClosed())
close(done)
})
})