forked from quic-go/quic-go
341 lines
10 KiB
Go
341 lines
10 KiB
Go
package http3
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/quic-go/quic-go"
|
|
mockquic "github.com/quic-go/quic-go/internal/mocks/quic"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
// canceledCtx returns a context derived from context.Background that is
// already canceled by the time the caller receives it.
func canceledCtx() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	// The deferred cancel runs before the function returns to the caller,
	// so the returned context is always observed as canceled.
	defer cancel()
	return ctx
}
|
|
|
|
func checkDatagramReceive(t *testing.T, str *stateTrackingStream) {
|
|
t.Helper()
|
|
_, err := str.ReceiveDatagram(canceledCtx())
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
}
|
|
|
|
func checkDatagramSend(t *testing.T, str *stateTrackingStream) {
|
|
t.Helper()
|
|
require.NoError(t, str.SendDatagram([]byte("test")))
|
|
}
|
|
|
|
type mockStreamClearer struct {
|
|
cleared *quic.StreamID
|
|
}
|
|
|
|
func (s *mockStreamClearer) clearStream(id quic.StreamID) {
|
|
s.cleared = &id
|
|
}
|
|
|
|
func TestStateTrackingStreamRead(t *testing.T) {
|
|
t.Run("io.EOF", func(t *testing.T) {
|
|
testStateTrackingStreamRead(t, io.EOF)
|
|
})
|
|
t.Run("remote stream reset", func(t *testing.T) {
|
|
testStateTrackingStreamRead(t, assert.AnError)
|
|
})
|
|
}
|
|
|
|
func testStateTrackingStreamRead(t *testing.T, expectedErr error) {
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(quic.StreamID(1337))
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
var clearer mockStreamClearer
|
|
str := newStateTrackingStream(qstr, &clearer, func(b []byte) error { return nil })
|
|
|
|
// deadline errors are ignored
|
|
qstr.EXPECT().Read(gomock.Any()).Return(0, os.ErrDeadlineExceeded)
|
|
_, err := str.Read(make([]byte, 3))
|
|
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
|
require.Nil(t, clearer.cleared)
|
|
|
|
if expectedErr == io.EOF {
|
|
buf := bytes.NewBuffer([]byte("foobar"))
|
|
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
|
|
|
for range 3 {
|
|
_, err := str.Read([]byte{0})
|
|
require.NoError(t, err)
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
}
|
|
} else {
|
|
qstr.EXPECT().Read(gomock.Any()).Return(0, expectedErr).AnyTimes()
|
|
}
|
|
|
|
_, err = io.ReadAll(str)
|
|
if expectedErr == io.EOF {
|
|
require.NoError(t, err)
|
|
} else {
|
|
require.ErrorIs(t, err, expectedErr)
|
|
}
|
|
require.Nil(t, clearer.cleared)
|
|
// the receive side registered the error
|
|
_, err = str.ReceiveDatagram(canceledCtx())
|
|
require.ErrorIs(t, err, expectedErr)
|
|
// the send side is still open
|
|
require.NoError(t, str.SendDatagram([]byte("foo")))
|
|
}
|
|
|
|
func TestStateTrackingStreamWrite(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(quic.StreamID(1337))
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
var clearer mockStreamClearer
|
|
str := newStateTrackingStream(qstr, &clearer, func(b []byte) error { return nil })
|
|
|
|
qstr.EXPECT().Write([]byte("foo")).Return(3, nil)
|
|
qstr.EXPECT().Write([]byte("baz")).Return(0, os.ErrDeadlineExceeded)
|
|
qstr.EXPECT().Write([]byte("bar")).Return(0, assert.AnError)
|
|
|
|
_, err := str.Write([]byte("foo"))
|
|
require.NoError(t, err)
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
checkDatagramSend(t, str)
|
|
|
|
// deadline errors are ignored
|
|
_, err = str.Write([]byte("baz"))
|
|
require.ErrorIs(t, err, os.ErrDeadlineExceeded)
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
checkDatagramSend(t, str)
|
|
|
|
_, err = str.Write([]byte("bar"))
|
|
require.ErrorIs(t, err, assert.AnError)
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
require.ErrorIs(t, str.SendDatagram([]byte("test")), assert.AnError)
|
|
}
|
|
|
|
func TestStateTrackingStreamCancelRead(t *testing.T) {
|
|
const streamID quic.StreamID = 42
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(streamID)
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
var clearer mockStreamClearer
|
|
str := newStateTrackingStream(qstr, &clearer, func(b []byte) error { return nil })
|
|
|
|
buf := bytes.NewBuffer([]byte("foobar"))
|
|
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
|
qstr.EXPECT().CancelRead(quic.StreamErrorCode(1337))
|
|
_, err := str.Read(make([]byte, 3))
|
|
require.NoError(t, err)
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
checkDatagramSend(t, str)
|
|
|
|
str.CancelRead(1337)
|
|
require.Nil(t, clearer.cleared)
|
|
_, err = str.ReceiveDatagram(canceledCtx())
|
|
require.ErrorIs(t, err, &quic.StreamError{StreamID: streamID, ErrorCode: 1337})
|
|
checkDatagramSend(t, str)
|
|
}
|
|
|
|
func TestStateTrackingStreamCancelWrite(t *testing.T) {
|
|
const streamID quic.StreamID = 1234
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(streamID)
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
var clearer mockStreamClearer
|
|
str := newStateTrackingStream(qstr, &clearer, func(b []byte) error { return nil })
|
|
|
|
qstr.EXPECT().Write(gomock.Any())
|
|
qstr.EXPECT().CancelWrite(quic.StreamErrorCode(1337))
|
|
_, err := str.Write([]byte("foobar"))
|
|
require.NoError(t, err)
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
checkDatagramSend(t, str)
|
|
|
|
str.CancelWrite(1337)
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
require.ErrorIs(t, str.SendDatagram([]byte("test")), &quic.StreamError{StreamID: streamID, ErrorCode: 1337})
|
|
}
|
|
|
|
func TestStateTrackingStreamContext(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes()
|
|
ctx, cancel := context.WithCancelCause(context.Background())
|
|
qstr.EXPECT().Context().Return(ctx).AnyTimes()
|
|
|
|
var clearer mockStreamClearer
|
|
str := newStateTrackingStream(qstr, &clearer, func(b []byte) error { return nil })
|
|
|
|
require.Nil(t, clearer.cleared)
|
|
checkDatagramReceive(t, str)
|
|
checkDatagramSend(t, str)
|
|
|
|
cancel(assert.AnError)
|
|
require.Eventually(t, func() bool {
|
|
err := str.SendDatagram([]byte("test"))
|
|
if err == nil {
|
|
return false
|
|
}
|
|
require.ErrorIs(t, err, assert.AnError)
|
|
return true
|
|
}, time.Second, scaleDuration(5*time.Millisecond))
|
|
|
|
checkDatagramReceive(t, str)
|
|
require.Nil(t, clearer.cleared)
|
|
}
|
|
|
|
func TestStateTrackingStreamReceiveThenSend(t *testing.T) {
|
|
streamID := quic.StreamID(1234)
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(streamID)
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
var clearer mockStreamClearer
|
|
str := newStateTrackingStream(qstr, &clearer, func(b []byte) error { return nil })
|
|
|
|
buf := bytes.NewBuffer([]byte("foobar"))
|
|
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
|
_, err := io.ReadAll(str)
|
|
require.NoError(t, err)
|
|
|
|
require.Nil(t, clearer.cleared)
|
|
_, err = str.ReceiveDatagram(canceledCtx())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
|
|
testErr := errors.New("test error")
|
|
qstr.EXPECT().Write([]byte("bar")).Return(0, testErr)
|
|
|
|
_, err = str.Write([]byte("bar"))
|
|
require.ErrorIs(t, err, testErr)
|
|
require.ErrorIs(t, str.SendDatagram([]byte("test")), testErr)
|
|
|
|
require.Equal(t, &streamID, clearer.cleared)
|
|
}
|
|
|
|
func TestStateTrackingStreamSendThenReceive(t *testing.T) {
|
|
streamID := quic.StreamID(1337)
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(streamID)
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
var clearer mockStreamClearer
|
|
str := newStateTrackingStream(qstr, &clearer, func(b []byte) error { return nil })
|
|
|
|
testErr := errors.New("test error")
|
|
qstr.EXPECT().Write([]byte("bar")).Return(0, testErr)
|
|
|
|
_, err := str.Write([]byte("bar"))
|
|
require.ErrorIs(t, err, testErr)
|
|
require.Nil(t, clearer.cleared)
|
|
require.ErrorIs(t, str.SendDatagram([]byte("test")), testErr)
|
|
|
|
buf := bytes.NewBuffer([]byte("foobar"))
|
|
qstr.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes()
|
|
|
|
_, err = io.ReadAll(str)
|
|
require.NoError(t, err)
|
|
_, err = str.ReceiveDatagram(canceledCtx())
|
|
require.ErrorIs(t, err, io.EOF)
|
|
|
|
require.Equal(t, &streamID, clearer.cleared)
|
|
}
|
|
|
|
func TestDatagramReceiving(t *testing.T) {
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(quic.StreamID(1337))
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
str := newStateTrackingStream(qstr, nil, func(b []byte) error { return nil })
|
|
type result struct {
|
|
data []byte
|
|
err error
|
|
}
|
|
|
|
// Receive blocks until a datagram is received
|
|
resultChan := make(chan result)
|
|
go func() {
|
|
defer close(resultChan)
|
|
data, err := str.ReceiveDatagram(context.Background())
|
|
resultChan <- result{data: data, err: err}
|
|
}()
|
|
|
|
select {
|
|
case <-time.After(scaleDuration(10 * time.Millisecond)):
|
|
case <-resultChan:
|
|
t.Fatal("should not have received a datagram")
|
|
}
|
|
str.enqueueDatagram([]byte("foobar"))
|
|
|
|
select {
|
|
case res := <-resultChan:
|
|
require.NoError(t, res.err)
|
|
require.Equal(t, []byte("foobar"), res.data)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("should have received a datagram")
|
|
}
|
|
|
|
// up to 32 datagrams can be queued
|
|
for i := range streamDatagramQueueLen + 1 {
|
|
str.enqueueDatagram([]byte{uint8(i)})
|
|
}
|
|
for i := range streamDatagramQueueLen {
|
|
data, err := str.ReceiveDatagram(context.Background())
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte{uint8(i)}, data)
|
|
}
|
|
|
|
// Receive respects the context
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
_, err := str.ReceiveDatagram(ctx)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
}
|
|
|
|
func TestDatagramSending(t *testing.T) {
|
|
var sendQueue [][]byte
|
|
errors := []error{nil, nil, assert.AnError}
|
|
mockCtrl := gomock.NewController(t)
|
|
qstr := mockquic.NewMockStream(mockCtrl)
|
|
qstr.EXPECT().StreamID().AnyTimes().Return(quic.StreamID(1337))
|
|
qstr.EXPECT().Context().Return(context.Background()).AnyTimes()
|
|
|
|
str := newStateTrackingStream(qstr, nil, func(b []byte) error {
|
|
sendQueue = append(sendQueue, b)
|
|
err := errors[0]
|
|
errors = errors[1:]
|
|
return err
|
|
})
|
|
require.NoError(t, str.SendDatagram([]byte("foo")))
|
|
require.NoError(t, str.SendDatagram([]byte("bar")))
|
|
require.ErrorIs(t, str.SendDatagram([]byte("baz")), assert.AnError)
|
|
require.Equal(t, [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, sendQueue)
|
|
|
|
str.closeSend(net.ErrClosed)
|
|
require.ErrorIs(t, str.SendDatagram([]byte("foobar")), net.ErrClosed)
|
|
}
|