forked from quic-go/quic-go
replace Session.WaitUntilClosed() by a context
This commit is contained in:
@@ -2,6 +2,7 @@ package h2quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -61,7 +62,7 @@ func (s *mockSession) LocalAddr() net.Addr {
|
||||
func (s *mockSession) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: []byte{127, 0, 0, 1}, Port: 42}
|
||||
}
|
||||
func (s *mockSession) WaitUntilClosed() { panic("not implemented") }
|
||||
func (s *mockSession) Context() context.Context { panic("not implemented") }
|
||||
|
||||
var _ = Describe("H2 server", func() {
|
||||
var (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
@@ -56,9 +57,9 @@ type Session interface {
|
||||
RemoteAddr() net.Addr
|
||||
// Close closes the connection. The error will be sent to the remote peer in a CONNECTION_CLOSE frame. An error value of nil is allowed and will cause a normal PeerGoingAway to be sent.
|
||||
Close(error) error
|
||||
// WaitUntilClosed() blocks until the session is closed.
|
||||
// The context is cancelled when the session is closed.
|
||||
// Warning: This API should not be considered stable and might change soon.
|
||||
WaitUntilClosed()
|
||||
Context() context.Context
|
||||
}
|
||||
|
||||
// A NonFWSession is a QUIC connection between two peers half-way through the handshake.
|
||||
|
||||
@@ -2,6 +2,7 @@ package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
@@ -40,9 +41,6 @@ func (s *mockSession) run() error {
|
||||
func (s *mockSession) WaitUntilHandshakeComplete() error {
|
||||
return <-s.handshakeComplete
|
||||
}
|
||||
func (*mockSession) WaitUntilClosed() {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (s *mockSession) Close(e error) error {
|
||||
if s.closed {
|
||||
return nil
|
||||
@@ -58,21 +56,14 @@ func (s *mockSession) closeRemote(e error) {
|
||||
s.closedRemote = true
|
||||
close(s.stopRunLoop)
|
||||
}
|
||||
func (s *mockSession) AcceptStream() (Stream, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (s *mockSession) OpenStream() (Stream, error) {
|
||||
return &stream{streamID: 1337}, nil
|
||||
}
|
||||
func (s *mockSession) OpenStreamSync() (Stream, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (s *mockSession) LocalAddr() net.Addr {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (s *mockSession) RemoteAddr() net.Addr {
|
||||
panic("not implemented")
|
||||
}
|
||||
func (s *mockSession) AcceptStream() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) OpenStreamSync() (Stream, error) { panic("not implemented") }
|
||||
func (s *mockSession) LocalAddr() net.Addr { panic("not implemented") }
|
||||
func (s *mockSession) RemoteAddr() net.Addr { panic("not implemented") }
|
||||
func (*mockSession) Context() context.Context { panic("not implemented") }
|
||||
|
||||
var _ Session = &mockSession{}
|
||||
var _ NonFWSession = &mockSession{}
|
||||
|
||||
17
session.go
17
session.go
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -78,11 +79,11 @@ type session struct {
|
||||
sendingScheduled chan struct{}
|
||||
// closeChan is used to notify the run loop that it should terminate.
|
||||
closeChan chan closeError
|
||||
// runClosed is closed once the run loop exits
|
||||
// it is used to block Close() and WaitUntilClosed()
|
||||
runClosed chan struct{}
|
||||
closeOnce sync.Once
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
// when we receive too many undecryptable packets during the handshake, we send a Public reset
|
||||
// but only after a time of protocol.PublicResetTimeout has passed
|
||||
undecryptablePackets []*receivedPacket
|
||||
@@ -167,12 +168,12 @@ func (s *session) setup(
|
||||
s.aeadChanged = aeadChanged
|
||||
handshakeChan := make(chan handshakeEvent, 3)
|
||||
s.handshakeChan = handshakeChan
|
||||
s.runClosed = make(chan struct{})
|
||||
s.handshakeCompleteChan = make(chan error, 1)
|
||||
s.receivedPackets = make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets)
|
||||
s.closeChan = make(chan closeError, 1)
|
||||
s.sendingScheduled = make(chan struct{}, 1)
|
||||
s.undecryptablePackets = make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets)
|
||||
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
s.timer = utils.NewTimer()
|
||||
now := time.Now()
|
||||
@@ -333,12 +334,12 @@ runLoop:
|
||||
s.handshakeChan <- handshakeEvent{err: closeErr.err}
|
||||
}
|
||||
s.handleCloseError(closeErr)
|
||||
close(s.runClosed)
|
||||
defer s.ctxCancel()
|
||||
return closeErr.err
|
||||
}
|
||||
|
||||
func (s *session) WaitUntilClosed() {
|
||||
<-s.runClosed
|
||||
func (s *session) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *session) maybeResetTimer() {
|
||||
@@ -543,7 +544,7 @@ func (s *session) closeRemote(e error) {
|
||||
// It waits until the run loop has stopped before returning
|
||||
func (s *session) Close(e error) error {
|
||||
s.closeLocal(e)
|
||||
<-s.runClosed
|
||||
<-s.ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
@@ -648,7 +649,7 @@ var _ = Describe("Session", func() {
|
||||
str, _ := sess.GetOrOpenStream(5)
|
||||
err := sess.handleFrames([]frames.Frame{&frames.ConnectionCloseFrame{ErrorCode: 42, ReasonPhrase: "foobar"}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
_, err = str.Read([]byte{0})
|
||||
Expect(err).To(MatchError(qerr.Error(42, "foobar")))
|
||||
close(done)
|
||||
@@ -744,11 +745,11 @@ var _ = Describe("Session", func() {
|
||||
}()
|
||||
go sess.run()
|
||||
Consistently(func() error { return err }).ShouldNot(HaveOccurred())
|
||||
Expect(sess.runClosed).ToNot(BeClosed())
|
||||
Expect(sess.Context().Done()).ToNot(BeClosed())
|
||||
sess.Close(errCloseSessionForNewVersion)
|
||||
Eventually(func() error { return err }).Should(HaveOccurred())
|
||||
Expect(err).To(MatchError(qerr.Error(qerr.InternalError, errCloseSessionForNewVersion.Error())))
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -764,7 +765,7 @@ var _ = Describe("Session", func() {
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(mconn.written[0]).To(ContainSubstring(string([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0})))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
It("only closes once", func() {
|
||||
@@ -772,7 +773,7 @@ var _ = Describe("Session", func() {
|
||||
sess.Close(nil)
|
||||
Eventually(areSessionsRunning).Should(BeFalse())
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
It("closes streams with proper error", func() {
|
||||
@@ -787,7 +788,7 @@ var _ = Describe("Session", func() {
|
||||
n, err = s.Write([]byte{0})
|
||||
Expect(n).To(BeZero())
|
||||
Expect(err.Error()).To(ContainSubstring(testErr.Error()))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
It("closes the session in order to replace it with another QUIC version", func() {
|
||||
@@ -800,13 +801,16 @@ var _ = Describe("Session", func() {
|
||||
sess.Close(handshake.ErrHOLExperiment)
|
||||
Expect(mconn.written).To(HaveLen(1))
|
||||
Expect(mconn.written[0][0] & 0x02).ToNot(BeZero()) // Public Reset
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
})
|
||||
|
||||
It("unblocks WaitUntilClosed when the run loop exists", func() {
|
||||
It("cancels the context when the run loop exists", func() {
|
||||
returned := make(chan struct{})
|
||||
go func() {
|
||||
sess.WaitUntilClosed()
|
||||
defer GinkgoRecover()
|
||||
ctx := sess.Context()
|
||||
<-ctx.Done()
|
||||
Expect(ctx.Err()).To(MatchError(context.Canceled))
|
||||
close(returned)
|
||||
}()
|
||||
Consistently(returned).ShouldNot(BeClosed())
|
||||
@@ -841,7 +845,7 @@ var _ = Describe("Session", func() {
|
||||
sess.unpacker.(*mockUnpacker).unpackErr = testErr
|
||||
sess.handlePacket(&receivedPacket{publicHeader: hdr})
|
||||
Eventually(func() error { return runErr }).Should(MatchError(testErr))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -1362,7 +1366,7 @@ var _ = Describe("Session", func() {
|
||||
sess.scheduleSending() // wake up the run loop
|
||||
Eventually(func() [][]byte { return mconn.written }).Should(HaveLen(1))
|
||||
Expect(mconn.written[0]).To(ContainSubstring("PRST"))
|
||||
Eventually(sess.runClosed).Should(BeClosed())
|
||||
Eventually(sess.Context().Done()).Should(BeClosed())
|
||||
})
|
||||
|
||||
It("doesn't send a Public Reset if decrypting them suceeded during the timeout", func() {
|
||||
@@ -1372,7 +1376,7 @@ var _ = Describe("Session", func() {
|
||||
// there are no packets in the undecryptable packet queue
|
||||
// in reality, this happens when the trial decryption succeeded during the Public Reset timeout
|
||||
Consistently(func() [][]byte { return mconn.written }).ShouldNot(HaveLen(1))
|
||||
Expect(sess.runClosed).ToNot(Receive())
|
||||
Expect(sess.Context().Done()).ToNot(Receive())
|
||||
sess.Close(nil)
|
||||
})
|
||||
|
||||
@@ -1477,7 +1481,7 @@ var _ = Describe("Session", func() {
|
||||
err := sess.run() // Would normally not return
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -1486,7 +1490,7 @@ var _ = Describe("Session", func() {
|
||||
err := sess.run() // Would normally not return
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.HandshakeTimeout))
|
||||
Expect(mconn.written[0]).To(ContainSubstring("Crypto handshake did not complete in time."))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -1500,7 +1504,7 @@ var _ = Describe("Session", func() {
|
||||
err := sess.run() // Would normally not return
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
|
||||
@@ -1515,7 +1519,7 @@ var _ = Describe("Session", func() {
|
||||
err := sess.run() // Would normally not return
|
||||
Expect(err.(*qerr.QuicError).ErrorCode).To(Equal(qerr.NetworkIdleTimeout))
|
||||
Expect(mconn.written[0]).To(ContainSubstring("No recent network activity."))
|
||||
Expect(sess.runClosed).To(BeClosed())
|
||||
Expect(sess.Context().Done()).To(BeClosed())
|
||||
close(done)
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user