package http3

import (
	"context"
	"errors"
	"os"
	"sync"

	"github.com/quic-go/quic-go"
)

var _ quic.Stream = &stateTrackingStream{}

// stateTrackingStream wraps a quic.Stream and delegates to it, while keeping
// track of the first error seen on the send side and on the receive side.
//
// Each recorded error is forwarded to an errorSetter (intended to be occupied
// by a datagrammer). When an error is recorded for one direction and the
// other direction already has one, the stream's ID is handed to the
// streamClearer so the parent connection can drop the stream.
type stateTrackingStream struct {
	quic.Stream

	// mx guards sendErr and recvErr.
	mx      sync.Mutex
	sendErr error
	recvErr error

	clearer streamClearer
	setter  errorSetter
}

// streamClearer is notified with the stream ID once both directions of a
// stateTrackingStream have recorded an error.
type streamClearer interface {
	clearStream(quic.StreamID)
}

// errorSetter receives the send-side and receive-side errors recorded by a
// stateTrackingStream.
type errorSetter interface {
	SetSendError(error)
	SetReceiveError(error)
}

// newStateTrackingStream wraps s and wires it to the given clearer and setter.
// It also registers a callback on the stream's context: once that context is
// done, the send side is closed with the context's cause.
func newStateTrackingStream(s quic.Stream, clearer streamClearer, setter errorSetter) *stateTrackingStream {
	t := &stateTrackingStream{Stream: s, clearer: clearer, setter: setter}

	onDone := func() {
		t.closeSend(context.Cause(s.Context()))
	}
	context.AfterFunc(s.Context(), onDone)

	return t
}

// closeSend records e as the send-side error and forwards it to the setter.
// Once a send error has been recorded, later calls are no-ops. If a receive
// error was already recorded, the stream is cleared before the setter is
// informed.
func (s *stateTrackingStream) closeSend(e error) {
	s.mx.Lock()
	defer s.mx.Unlock()

	if s.sendErr != nil {
		// The send side was already closed; keep the first error.
		return
	}
	if s.recvErr != nil {
		// Both directions are now finished.
		s.clearer.clearStream(s.StreamID())
	}
	s.setter.SetSendError(e)
	s.sendErr = e
}

// closeReceive records e as the receive-side error and forwards it to the
// setter. Once a receive error has been recorded, later calls are no-ops. If a
// send error was already recorded, the stream is cleared before the setter is
// informed.
func (s *stateTrackingStream) closeReceive(e error) {
	s.mx.Lock()
	defer s.mx.Unlock()

	if s.recvErr != nil {
		// The receive side was already closed; keep the first error.
		return
	}
	if s.sendErr != nil {
		// Both directions are now finished.
		s.clearer.clearStream(s.StreamID())
	}
	s.setter.SetReceiveError(e)
	s.recvErr = e
}

// Close marks the send side as closed and then closes the underlying stream,
// returning the underlying stream's result.
func (s *stateTrackingStream) Close() error {
	closedErr := errors.New("write on closed stream")
	s.closeSend(closedErr)
	return s.Stream.Close()
}

// CancelWrite marks the send side as closed with a quic.StreamError carrying
// the given code, then cancels writing on the underlying stream.
func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
	streamErr := &quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}
	s.closeSend(streamErr)
	s.Stream.CancelWrite(e)
}

// Write forwards to the underlying stream. Any error other than a deadline
// expiry marks the send side as closed; the byte count and error from the
// underlying stream are returned unchanged.
func (s *stateTrackingStream) Write(b []byte) (int, error) {
	n, err := s.Stream.Write(b)
	if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
		// Success, or a deadline error that does not close the send side.
		return n, err
	}
	s.closeSend(err)
	return n, err
}

// CancelRead marks the receive side as closed with a quic.StreamError carrying
// the given code, then cancels reading on the underlying stream.
func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
	streamErr := &quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e}
	s.closeReceive(streamErr)
	s.Stream.CancelRead(e)
}

// Read forwards to the underlying stream. Any error other than a deadline
// expiry marks the receive side as closed; the byte count and error from the
// underlying stream are returned unchanged.
func (s *stateTrackingStream) Read(b []byte) (int, error) {
	n, err := s.Stream.Read(b)
	if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
		// Success, or a deadline error that does not close the receive side.
		return n, err
	}
	s.closeReceive(err)
	return n, err
}