forked from quic-go/quic-go
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
type streamCreator interface {
|
||||
GetOrOpenStream(protocol.StreamID) (utils.Stream, error)
|
||||
Close(error, bool) error
|
||||
Close(error) error
|
||||
}
|
||||
|
||||
// Server is a HTTP2 server listening for QUIC connections
|
||||
@@ -111,7 +111,7 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
|
||||
}
|
||||
if s.CloseAfterFirstRequest {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
session.Close(nil, true)
|
||||
session.Close(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ func (s *mockSession) GetOrOpenStream(id protocol.StreamID) (utils.Stream, error
|
||||
return &mockStream{}, nil
|
||||
}
|
||||
|
||||
func (s *mockSession) Close(error, bool) error { s.closed = true; return nil }
|
||||
func (s *mockSession) Close(error) error { s.closed = true; return nil }
|
||||
|
||||
var _ = Describe("H2 server", func() {
|
||||
var (
|
||||
|
||||
43
session.go
43
session.go
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/ackhandler"
|
||||
@@ -58,7 +59,7 @@ type Session struct {
|
||||
receivedPackets chan receivedPacket
|
||||
sendingScheduled chan struct{}
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
closed uint32 // atomic bool
|
||||
|
||||
undecryptablePackets []receivedPacket
|
||||
aeadChanged chan struct{}
|
||||
@@ -102,7 +103,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
|
||||
go func() {
|
||||
if err := cryptoSetup.HandleCryptoStream(); err != nil {
|
||||
session.Close(err, true)
|
||||
session.Close(err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -160,22 +161,20 @@ func (s *Session) run() {
|
||||
case <-s.aeadChanged:
|
||||
s.tryDecryptingQueuedPackets()
|
||||
case <-time.After(s.connectionParametersManager.GetIdleConnectionStateLifetime()):
|
||||
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."), true)
|
||||
s.Close(qerr.Error(qerr.NetworkIdleTimeout, "No recent network activity."))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
switch err {
|
||||
// Can happen e.g. when packets thought missing arrive late
|
||||
case ackhandler.ErrDuplicateOrOutOfOrderAck:
|
||||
// Can happen when RST_STREAMs arrive early or late (?)
|
||||
case ackhandler.ErrMapAccess:
|
||||
s.Close(err, true) // TODO: sent correct error code here
|
||||
// Can happen e.g. when packets thought missing arrive late
|
||||
case errRstStreamOnInvalidStream:
|
||||
// Can happen when RST_STREAMs arrive early or late (?)
|
||||
utils.Errorf("Ignoring error in session: %s", err.Error())
|
||||
// Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream
|
||||
case errWindowUpdateOnClosedStream:
|
||||
// Can happen when we already sent the last StreamFrame with the FinBit, but the client already sent a WindowUpdate for this Stream
|
||||
default:
|
||||
s.Close(err, true)
|
||||
s.Close(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,9 +215,8 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da
|
||||
case *frames.AckFrame:
|
||||
err = s.handleAckFrame(frame)
|
||||
case *frames.ConnectionCloseFrame:
|
||||
// ToDo: send right error in ConnectionClose frame
|
||||
utils.Debugf("\t<- %#v", frame)
|
||||
s.Close(nil, false)
|
||||
s.closeImpl(qerr.Error(frame.ErrorCode, frame.ReasonPhrase), true)
|
||||
case *frames.StopWaitingFrame:
|
||||
utils.Debugf("\t<- %#v", frame)
|
||||
err = s.receivedPacketHandler.ReceivedStopWaiting(frame)
|
||||
@@ -348,25 +346,28 @@ func (s *Session) handleAckFrame(frame *frames.AckFrame) error {
|
||||
}
|
||||
|
||||
// Close the connection
|
||||
func (s *Session) Close(e error, sendConnectionClose bool) error {
|
||||
if s.closed {
|
||||
func (s *Session) Close(e error) error {
|
||||
return s.closeImpl(e, false)
|
||||
}
|
||||
|
||||
func (s *Session) closeImpl(e error, remoteClose bool) error {
|
||||
// Only close once
|
||||
if !atomic.CompareAndSwapUint32(&s.closed, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
s.closeChan <- struct{}{}
|
||||
|
||||
s.closeCallback(s.connectionID)
|
||||
|
||||
if !sendConnectionClose {
|
||||
return nil
|
||||
}
|
||||
|
||||
if e == nil {
|
||||
e = qerr.PeerGoingAway
|
||||
}
|
||||
|
||||
utils.Errorf("Closing session with error: %s", e.Error())
|
||||
s.closeStreamsWithError(e)
|
||||
s.closeCallback(s.connectionID)
|
||||
|
||||
if remoteClose {
|
||||
return nil
|
||||
}
|
||||
|
||||
quicErr := qerr.ToQuicError(e)
|
||||
if quicErr.ErrorCode == qerr.DecryptionFailure {
|
||||
@@ -622,7 +623,7 @@ func (s *Session) congestionAllowsSending() bool {
|
||||
func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) {
|
||||
utils.Debugf("Queueing packet 0x%x for later decryption", p.publicHeader.PacketNumber)
|
||||
if len(s.undecryptablePackets)+1 >= protocol.MaxUndecryptablePackets {
|
||||
s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"), true)
|
||||
s.Close(qerr.Error(qerr.DecryptionFailure, "too many undecryptable packets received"))
|
||||
}
|
||||
s.undecryptablePackets = append(s.undecryptablePackets, p)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
@@ -328,16 +329,25 @@ var _ = Describe("Session", func() {
|
||||
})
|
||||
|
||||
It("shuts down without error", func() {
|
||||
session.Close(nil, true)
|
||||
session.Close(nil)
|
||||
Expect(closed).To(BeTrue())
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
Expect(conn.written).To(HaveLen(1))
|
||||
Expect(conn.written[0][len(conn.written[0])-7:]).To(Equal([]byte{0x02, byte(qerr.PeerGoingAway), 0, 0, 0, 0, 0}))
|
||||
})
|
||||
|
||||
It("only closes once", func() {
|
||||
session.Close(nil)
|
||||
session.Close(nil)
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
Expect(conn.written).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("closes streams with proper error", func() {
|
||||
testErr := errors.New("test error")
|
||||
s, err := session.OpenStream(5)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
session.Close(testErr, true)
|
||||
session.Close(testErr)
|
||||
Expect(closed).To(BeTrue())
|
||||
Eventually(func() int { return runtime.NumGoroutine() }).Should(Equal(nGoRoutinesBefore))
|
||||
n, err := s.Read([]byte{0})
|
||||
@@ -554,7 +564,7 @@ var _ = Describe("Session", func() {
|
||||
Data: []byte("4242\x00\x00\x00\x00"),
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Eventually(func() bool { return session.closed }).Should(BeTrue())
|
||||
Eventually(func() bool { return atomic.LoadUint32(&session.closed) != 0 }).Should(BeTrue())
|
||||
_, err = s.Write([]byte{})
|
||||
Expect(err).To(MatchError(qerr.InvalidCryptoMessageType))
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user