forked from quic-go/quic-go
96 lines
2.1 KiB
Go
96 lines
2.1 KiB
Go
package http3
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
)
|
|
|
|
type streamState uint8
|
|
|
|
const (
|
|
streamStateOpen streamState = iota
|
|
streamStateReceiveClosed
|
|
streamStateSendClosed
|
|
streamStateSendAndReceiveClosed
|
|
)
|
|
|
|
type stateTrackingStream struct {
|
|
quic.Stream
|
|
|
|
mx sync.Mutex
|
|
state streamState
|
|
|
|
onStateChange func(streamState, error)
|
|
}
|
|
|
|
func newStateTrackingStream(s quic.Stream, onStateChange func(streamState, error)) *stateTrackingStream {
|
|
context.AfterFunc(s.Context(), func() {
|
|
onStateChange(streamStateSendClosed, context.Cause(s.Context()))
|
|
})
|
|
return &stateTrackingStream{
|
|
Stream: s,
|
|
state: streamStateOpen,
|
|
onStateChange: onStateChange,
|
|
}
|
|
}
|
|
|
|
var _ quic.Stream = &stateTrackingStream{}
|
|
|
|
func (s *stateTrackingStream) closeSend(e error) {
|
|
s.mx.Lock()
|
|
defer s.mx.Unlock()
|
|
|
|
if s.state == streamStateReceiveClosed || s.state == streamStateSendAndReceiveClosed {
|
|
s.state = streamStateSendAndReceiveClosed
|
|
} else {
|
|
s.state = streamStateSendClosed
|
|
}
|
|
s.onStateChange(s.state, e)
|
|
}
|
|
|
|
func (s *stateTrackingStream) closeReceive(e error) {
|
|
s.mx.Lock()
|
|
defer s.mx.Unlock()
|
|
|
|
if s.state == streamStateSendClosed || s.state == streamStateSendAndReceiveClosed {
|
|
s.state = streamStateSendAndReceiveClosed
|
|
} else {
|
|
s.state = streamStateReceiveClosed
|
|
}
|
|
s.onStateChange(s.state, e)
|
|
}
|
|
|
|
func (s *stateTrackingStream) Close() error {
|
|
s.closeSend(errors.New("write on closed stream"))
|
|
return s.Stream.Close()
|
|
}
|
|
|
|
func (s *stateTrackingStream) CancelWrite(e quic.StreamErrorCode) {
|
|
s.closeSend(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
|
|
s.Stream.CancelWrite(e)
|
|
}
|
|
|
|
func (s *stateTrackingStream) Write(b []byte) (int, error) {
|
|
n, err := s.Stream.Write(b)
|
|
if err != nil {
|
|
s.closeSend(err)
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (s *stateTrackingStream) CancelRead(e quic.StreamErrorCode) {
|
|
s.closeReceive(&quic.StreamError{StreamID: s.Stream.StreamID(), ErrorCode: e})
|
|
s.Stream.CancelRead(e)
|
|
}
|
|
|
|
func (s *stateTrackingStream) Read(b []byte) (int, error) {
|
|
n, err := s.Stream.Read(b)
|
|
if err != nil {
|
|
s.closeReceive(err)
|
|
}
|
|
return n, err
|
|
}
|