Files
quic-go/http3/state_tracking_stream.go
2025-03-30 07:16:14 +02:00

117 lines
2.6 KiB
Go

package http3
import (
"context"
"errors"
"os"
"sync"
"github.com/quic-go/quic-go"
)
// Compile-time check that *stateTrackingStream implements quic.Stream.
var _ quic.Stream = (*stateTrackingStream)(nil)
// stateTrackingStream is a quic.Stream that delegates every operation to an
// embedded underlying stream, while tracking when each direction finishes.
//
// When the send or receive side finishes, the error that finished it is
// forwarded to an errorSetter (intended to be a datagrammer). The first time
// both sides have finished, the stream's ID is handed to a streamClearer so
// that the parent connection can clear the stream.
type stateTrackingStream struct {
	quic.Stream

	// mx guards sendErr and recvErr. A nil error means that direction has
	// not been recorded as finished yet.
	mx      sync.Mutex
	sendErr error
	recvErr error

	clearer streamClearer // told to clear the stream once both sides are done
	setter  errorSetter   // receives the error that finished each side
}
type (
	// streamClearer clears a stream, identified by its ID, from whatever is
	// tracking it (presumably the parent connection).
	streamClearer interface {
		clearStream(quic.StreamID)
	}

	// errorSetter is notified of the errors that finished the send and
	// receive sides of a stream.
	errorSetter interface {
		SetSendError(error)
		SetReceiveError(error)
	}
)
// newStateTrackingStream wraps s in a stateTrackingStream that reports to
// clearer and setter. When s's context is done, the send side is recorded as
// finished, using the context's cause as the error.
func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
	tracked := &stateTrackingStream{Stream: s, clearer: clearer, setter: setter}

	// Treat the stream context being done as the send side finishing.
	onDone := func() {
		tracked.closeSend(context.Cause(s.Context()))
	}
	context.AfterFunc(s.Context(), onDone)

	return tracked
}
// closeSend records that the send side finished with error e. Once a non-nil
// error has been recorded, later calls do nothing. Otherwise e is forwarded to
// the errorSetter, and, if the receive side had already finished, the stream
// is cleared via the streamClearer.
func (s *stateTrackingStream) closeSend(e error) {
	s.mx.Lock()
	defer s.mx.Unlock()

	if s.sendErr != nil {
		return // send side already recorded as finished
	}

	// The side that finishes second is the one that clears the stream.
	receiveDone := s.recvErr != nil
	if receiveDone {
		s.clearer.clearStream(s.StreamID())
	}
	s.setter.SetSendError(e)
	s.sendErr = e
}
// closeReceive records that the receive side finished with error e. Once a
// non-nil error has been recorded, later calls do nothing. Otherwise e is
// forwarded to the errorSetter, and, if the send side had already finished,
// the stream is cleared via the streamClearer.
func (s *stateTrackingStream) closeReceive(e error) {
	s.mx.Lock()
	defer s.mx.Unlock()

	if s.recvErr != nil {
		return // receive side already recorded as finished
	}

	// The side that finishes second is the one that clears the stream.
	sendDone := s.sendErr != nil
	if sendDone {
		s.clearer.clearStream(s.StreamID())
	}
	s.setter.SetReceiveError(e)
	s.recvErr = e
}
// Close records the send side as finished with a "write on closed stream"
// error and then closes the underlying stream, returning its result.
func (s *stateTrackingStream) Close() error {
	closedErr := errors.New("write on closed stream")
	s.closeSend(closedErr)
	return s.Stream.Close()
}
// CancelWrite records the send side as finished with a *quic.StreamError that
// carries this stream's ID and error code e. It then cancels writing on the
// underlying stream.
func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
	cancelErr := &quic.StreamError{ErrorCode: e, StreamID: s.StreamID()}
	s.closeSend(cancelErr)
	s.Stream.CancelWrite(e)
}
// Write writes b to the underlying stream. Any error other than a deadline
// timeout (os.ErrDeadlineExceeded) marks the send side as finished.
func (s *stateTrackingStream) Write(b []byte) (int, error) {
	written, err := s.Stream.Write(b)
	if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
		// Success, or a deadline timeout, which is not treated as the
		// send side finishing.
		return written, err
	}
	s.closeSend(err)
	return written, err
}
// CancelRead records the receive side as finished with a *quic.StreamError
// that carries this stream's ID and error code e. It then cancels reading on
// the underlying stream.
func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
	cancelErr := &quic.StreamError{ErrorCode: e, StreamID: s.StreamID()}
	s.closeReceive(cancelErr)
	s.Stream.CancelRead(e)
}
// Read reads from the underlying stream into b. Any error other than a
// deadline timeout (os.ErrDeadlineExceeded), including io.EOF, marks the
// receive side as finished.
func (s *stateTrackingStream) Read(b []byte) (int, error) {
	read, err := s.Stream.Read(b)
	if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
		// Success, or a deadline timeout, which is not treated as the
		// receive side finishing.
		return read, err
	}
	s.closeReceive(err)
	return read, err
}