forked from quic-go/quic-go
http3: simplify HTTP datagram handling (#5156)
This commit is contained in:
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
var _ quic.Stream = &stateTrackingStream{}
|
||||
const streamDatagramQueueLen = 32
|
||||
|
||||
// stateTrackingStream is an implementation of quic.Stream that delegates
|
||||
// to an underlying stream
|
||||
@@ -21,28 +21,32 @@ var _ quic.Stream = &stateTrackingStream{}
|
||||
type stateTrackingStream struct {
|
||||
quic.Stream
|
||||
|
||||
sendDatagram func([]byte) error
|
||||
hasData chan struct{}
|
||||
queue [][]byte // TODO: use a ring buffer
|
||||
|
||||
mx sync.Mutex
|
||||
sendErr error
|
||||
recvErr error
|
||||
|
||||
clearer streamClearer
|
||||
setter errorSetter
|
||||
}
|
||||
|
||||
var (
|
||||
_ datagramStream = &stateTrackingStream{}
|
||||
_ quic.Stream = &stateTrackingStream{}
|
||||
)
|
||||
|
||||
type streamClearer interface {
|
||||
clearStream(quic.StreamID)
|
||||
}
|
||||
|
||||
type errorSetter interface {
|
||||
SetSendError(error)
|
||||
SetReceiveError(error)
|
||||
}
|
||||
|
||||
func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
|
||||
func newStateTrackingStream(s quic.Stream, clearer streamClearer, sendDatagram func([]byte) error) *stateTrackingStream {
|
||||
t := &stateTrackingStream{
|
||||
Stream: s,
|
||||
clearer: clearer,
|
||||
setter: setter,
|
||||
Stream: s,
|
||||
clearer: clearer,
|
||||
sendDatagram: sendDatagram,
|
||||
hasData: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
context.AfterFunc(s.Context(), func() {
|
||||
@@ -62,8 +66,6 @@ func (s *stateTrackingStream) closeSend(e error) {
|
||||
if s.recvErr != nil {
|
||||
s.clearer.clearStream(s.StreamID())
|
||||
}
|
||||
|
||||
s.setter.SetSendError(e)
|
||||
s.sendErr = e
|
||||
}
|
||||
}
|
||||
@@ -78,9 +80,8 @@ func (s *stateTrackingStream) closeReceive(e error) {
|
||||
if s.sendErr != nil {
|
||||
s.clearer.clearStream(s.StreamID())
|
||||
}
|
||||
|
||||
s.setter.SetReceiveError(e)
|
||||
s.recvErr = e
|
||||
s.signalHasDatagram()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +115,58 @@ func (s *stateTrackingStream) Read(b []byte) (int, error) {
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) SendDatagram(b []byte) error {
|
||||
s.mx.Lock()
|
||||
sendErr := s.sendErr
|
||||
s.mx.Unlock()
|
||||
if sendErr != nil {
|
||||
return sendErr
|
||||
}
|
||||
|
||||
return s.sendDatagram(b)
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) signalHasDatagram() {
|
||||
select {
|
||||
case s.hasData <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) enqueueDatagram(data []byte) {
|
||||
s.mx.Lock()
|
||||
defer s.mx.Unlock()
|
||||
|
||||
if s.recvErr != nil {
|
||||
return
|
||||
}
|
||||
if len(s.queue) >= streamDatagramQueueLen {
|
||||
return
|
||||
}
|
||||
s.queue = append(s.queue, data)
|
||||
s.signalHasDatagram()
|
||||
}
|
||||
|
||||
func (s *stateTrackingStream) ReceiveDatagram(ctx context.Context) ([]byte, error) {
|
||||
start:
|
||||
s.mx.Lock()
|
||||
if len(s.queue) > 0 {
|
||||
data := s.queue[0]
|
||||
s.queue = s.queue[1:]
|
||||
s.mx.Unlock()
|
||||
return data, nil
|
||||
}
|
||||
if receiveErr := s.recvErr; receiveErr != nil {
|
||||
s.mx.Unlock()
|
||||
return nil, receiveErr
|
||||
}
|
||||
s.mx.Unlock()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, context.Cause(ctx)
|
||||
case <-s.hasData:
|
||||
}
|
||||
goto start
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user