refactor session to remove second Close parameter

fixes #102
This commit is contained in:
Lucas Clemente
2016-05-17 14:56:19 +02:00
parent 42f3091e1b
commit 68b529a54c
4 changed files with 38 additions and 27 deletions

View File

@@ -16,7 +16,7 @@ import (
type streamCreator interface {
GetOrOpenStream(protocol.StreamID) (utils.Stream, error)
Close(error, bool) error
Close(error) error
}
// Server is a HTTP2 server listening for QUIC connections
@@ -111,7 +111,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
}
if s.CloseAfterFirstRequest {
time.Sleep(100 * time.Millisecond)
session.Close(nil, true)
session.Close(nil)
}
}()

View File

@@ -22,7 +22,7 @@ func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error
return &mockStream{}, nil
}
func (s *mockSession) Close(error, bool) error { s.closed = true; return nil }
func (s *mockSession) Close(error) error { s.closed = true; return nil }
var _ = Describe("H2 server", func() {
var (

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/ackhandler"
@@ -58,7 +59,7 @@ type Session struct {
receivedPackets chan receivedPacket
sendingScheduled chan struct{}
closeChan chan struct{}
closed bool
closed uint32 // atomic bool
undecryptablePackets []receivedPacket
aeadChanged chan struct{}
@@ -102,7 +103,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
go func() {
if err := cryptoSetup.HandleCryptoStream(); err != nil {
session.Close(err, true)
session.Close(err)
}
}()
@@ -160,22 +161,20 @@ func (s *Session) run() {
case <-s.aeadChanged:
s.tryDecryptingQueuedPackets()
case <-time.After(s.connectionParametersManager.GetIdleConnectionStateLifetime()):
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."), true)
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
}
if err != nil {
switch err {
// Can happen e.g. when packets thought missing arrive late
case ackhandler.ErrDuplicateOrOutOfOrderAck:
// Can happen when RST_STREAMs arrive early or late (?)
case ackhandler.ErrMapAccess:
s.Close(err, true) // TODO: sent correct error code here
// Can happen e.g. when packets thought missing arrive late
case errRstStreamOnInvalidStream:
// Can happen when RST_STREAMs arrive early or late (?)
utils.Errorf("Ignoring error in session: %s", err.Error())
// Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream
case errWindowUpdateOnClosedStream:
// Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream
default:
s.Close(err, true)
s.Close(err)
}
}
@@ -216,9 +215,8 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da
case *frames.AckFrame:
err = s.handleAckFrame(frame)
case *frames.ConnectionCloseFrame:
// ToDo: send right error in ConnectionClose frame
utils.Debugf("\t<- %#v", frame)
s.Close(nil, false)
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
case *frames.StopWaitingFrame:
utils.Debugf("\t<- %#v", frame)
err = s.receivedPacketHandler.ReceivedStopWaiting(frame)
@@ -348,25 +346,28 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error {
}
// Close the connection
func (s *Session) Close(e error, sendConnectionClose bool) error {
if s.closed {
func (s *Session) Close(e error) error {
return s.closeImpl(e, false)
}
func (s *Session) closeImpl(e error, remoteClose bool) error {
// Only close once
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
return nil
}
s.closed = true
s.closeChan <- struct{}{}
s.closeCallback(s.connectionID)
if !sendConnectionClose {
return nil
}
if e == nil {
e = qerr.PeerGoingAway
}
utils.Errorf("Closing session with error: %s", e.Error())
s.closeStreamsWithError(e)
s.closeCallback(s.connectionID)
if remoteClose {
return nil
}
quicErr := qerr.ToQuicError(e)
if quicErr.ErrorCode == qerr.DecryptionFailure {
@@ -622,7 +623,7 @@ func (s *Session) congestionAllowsSending() bool {
func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) {
utils.Debugf("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber)
if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets {
s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"), true)
s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
}
s.undecryptablePackets = append(s.undecryptablePackets, p)
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"io"
"runtime"
"sync/atomic"
"time"
. "github.com/onsi/ginkgo"
@@ -328,16 +329,25 @@ var _ = Describe("Session", func() {
})
It("shuts down without error", func() {
session.Close(nil, true)
session.Close(nil)
Expect(closed).To(BeTrue())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1))
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
})
It("only closes once", func() {
session.Close(nil)
session.Close(nil)
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
Expect(conn.written).To(HaveLen(1))
})
It("closes streams with proper error", func() {
testErr := errors.New("test error")
s, err := session.OpenStream(5)
Expect(err).NotTo(HaveOccurred())
session.Close(testErr, true)
session.Close(testErr)
Expect(closed).To(BeTrue())
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
n, err := s.Read([]byte{0})
@@ -554,7 +564,7 @@ var _ = Describe("Session", func() {
Data: []byte("4242\x00\x00\x00\x00"),
})
Expect(err).NotTo(HaveOccurred())
Eventually(func() bool { return session.closed }).Should(BeTrue())
Eventually(func() bool { return atomic.LoadUint32(&session.closed) != 0 }).Should(BeTrue())
_, err = s.Write([]byte{})
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
})