diff --git a/http3/conn.go b/http3/conn.go index 591afbb2..a2da426d 100644 --- a/http3/conn.go +++ b/http3/conn.go @@ -40,7 +40,11 @@ func newConnection( } func (c *connection) HandleUnidirectionalStreams() { - var rcvdControlStream atomic.Bool + var ( + rcvdControlStr atomic.Bool + rcvdQPACKEncoderStr atomic.Bool + rcvdQPACKDecoderStr atomic.Bool + ) for { str, err := c.quicConn.AcceptUniStream(context.Background()) @@ -61,9 +65,17 @@ func (c *connection) HandleUnidirectionalStreams() { // We're only interested in the control stream here. switch streamType { case streamTypeControlStream: - case streamTypeQPACKEncoderStream, streamTypeQPACKDecoderStream: + case streamTypeQPACKEncoderStream: + if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst { + c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream") + } + // Our QPACK implementation doesn't use the dynamic table yet. + return + case streamTypeQPACKDecoderStream: + if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst { + c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream") + } // Our QPACK implementation doesn't use the dynamic table yet. - // TODO: check that only one stream of each type is opened. return case streamTypePushStream: switch c.perspective { @@ -83,7 +95,7 @@ func (c *connection) HandleUnidirectionalStreams() { return } // Only a single control stream is allowed. - if isFirstControlStr := rcvdControlStream.CompareAndSwap(false, true); !isFirstControlStr { + if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr { c.quicConn.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream") return } diff --git a/http3/conn_test.go b/http3/conn_test.go index bd990582..643236cb 100644 --- a/http3/conn_test.go +++ b/http3/conn_test.go @@ -124,6 +124,41 @@ var _ = Describe("Connection", func() { }() Eventually(done).Should(BeClosed()) }) + + It(fmt.Sprintf("rejects duplicate QPACK %s streams", name), func() { + qconn := mockquic.NewMockEarlyConnection(mockCtrl) + conn := newConnection( + qconn, + false, + nil, + protocol.PerspectiveClient, + utils.DefaultLogger, + ) + buf := bytes.NewBuffer(quicvarint.Append(nil, streamType)) + str1 := mockquic.NewMockStream(mockCtrl) + str1.EXPECT().Read(gomock.Any()).DoAndReturn(buf.Read).AnyTimes() + buf2 := bytes.NewBuffer(quicvarint.Append(nil, streamType)) + str2 := mockquic.NewMockStream(mockCtrl) + str2.EXPECT().Read(gomock.Any()).DoAndReturn(buf2.Read).AnyTimes() + qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(str1, nil) + qconn.EXPECT().AcceptUniStream(gomock.Any()).Return(str2, nil) + testDone := make(chan struct{}) + qconn.EXPECT().AcceptUniStream(gomock.Any()).DoAndReturn(func(context.Context) (quic.ReceiveStream, error) { + <-testDone + return nil, errors.New("test done") + }) + qconn.EXPECT().CloseWithError(qerr.ApplicationErrorCode(ErrCodeStreamCreationError), gomock.Any()).Do(func(qerr.ApplicationErrorCode, string) error { + close(testDone) + return nil + }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + conn.HandleUnidirectionalStreams() + }() + Eventually(done).Should(BeClosed()) + }) } It("resets streams other than the control stream and the QPACK streams", func() {