diff --git a/connection.go b/connection.go index 9f53df564..9c8821d91 100644 --- a/connection.go +++ b/connection.go @@ -272,8 +272,8 @@ var newConnection = func( s.queueControlFrame, connIDGenerator, ) - s.preSetup() s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.preSetup() s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( 0, getMaxPacketSize(s.conn.RemoteAddr()), @@ -381,8 +381,8 @@ var newClientConnection = func( s.queueControlFrame, connIDGenerator, ) - s.preSetup() s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID)) + s.preSetup() s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler( initialPacketNumber, getMaxPacketSize(s.conn.RemoteAddr()), @@ -471,6 +471,7 @@ func (s *connection) preSetup() { ) s.earlyConnReadyChan = make(chan struct{}) s.streamsMap = newStreamsMap( + s.ctx, s, s.newFlowController, uint64(s.config.MaxIncomingStreams), diff --git a/integrationtests/self/http_test.go b/integrationtests/self/http_test.go index 05e065cf4..c8b1d2829 100644 --- a/integrationtests/self/http_test.go +++ b/integrationtests/self/http_test.go @@ -536,6 +536,7 @@ var _ = Describe("HTTP tests", func() { It("sets conn context", func() { type ctxKey int + var tracingID quic.ConnectionTracingID server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context { serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server) Expect(ok).To(BeTrue()) @@ -543,6 +544,7 @@ var _ = Describe("HTTP tests", func() { ctx = context.WithValue(ctx, ctxKey(0), "Hello") ctx = context.WithValue(ctx, ctxKey(1), c) + tracingID = c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) return ctx } mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) { @@ -558,6 +560,10 @@ var _ = Describe("HTTP tests", func() { serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server) Expect(ok).To(BeTrue()) Expect(serv).To(Equal(server)) + + id, ok := r.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID) + Expect(ok).To(BeTrue()) + Expect(id).To(Equal(tracingID)) }) resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port)) diff --git a/send_stream.go b/send_stream.go index e1ce3e677..7f9b0eb40 100644 --- a/send_stream.go +++ b/send_stream.go @@ -60,6 +60,7 @@ var ( ) func newSendStream( + ctx context.Context, streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, @@ -71,7 +72,7 @@ func newSendStream( writeChan: make(chan struct{}, 1), writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write } - s.ctx, s.ctxCancel = context.WithCancelCause(context.Background()) + s.ctx, s.ctxCancel = context.WithCancelCause(ctx) return s } diff --git a/send_stream_test.go b/send_stream_test.go index 507133080..1f9308baa 100644 --- a/send_stream_test.go +++ b/send_stream_test.go @@ -35,7 +35,7 @@ var _ = Describe("Send Stream", func() { BeforeEach(func() { mockSender = NewMockStreamSender(mockCtrl) mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newSendStream(streamID, mockSender, mockFC) + str = newSendStream(context.Background(), streamID, mockSender, mockFC) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = gbytes.TimeoutWriter(str, timeout) diff --git a/stream.go b/stream.go index ce4374d60..d6c809ad7 100644 --- a/stream.go +++ b/stream.go @@ -1,6 +1,7 @@ package quic import ( + "context" "net" "os" "sync" @@ -85,7 +86,9 @@ type stream struct { var _ Stream = &stream{} // newStream creates a new Stream -func newStream(streamID protocol.StreamID, +func newStream( + ctx context.Context, + streamID protocol.StreamID, sender streamSender, flowController flowcontrol.StreamFlowController, ) *stream { @@ -99,7 +102,7 @@ func newStream(streamID protocol.StreamID, s.completedMutex.Unlock() }, } - s.sendStream = *newSendStream(streamID, senderForSendStream, flowController) + s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController) senderForReceiveStream := &uniStreamSender{ streamSender: sender, onStreamCompletedImpl: func() { diff --git a/stream_test.go b/stream_test.go index e1e3804f9..5d0e64b5a 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,6 +1,7 @@ package quic import ( + "context" "errors" "io" "os" @@ -41,7 +42,7 @@ var _ = Describe("Stream", func() { BeforeEach(func() { mockSender = NewMockStreamSender(mockCtrl) mockFC = mocks.NewMockStreamFlowController(mockCtrl) - str = newStream(streamID, mockSender, mockFC) + str = newStream(context.Background(), streamID, mockSender, mockFC) timeout := scaleDuration(250 * time.Millisecond) strWithTimeout = struct { diff --git a/streams_map.go b/streams_map.go index b1a80eb36..0ba23b254 100644 --- a/streams_map.go +++ b/streams_map.go @@ -45,6 +45,7 @@ func (streamOpenErr) Timeout() bool { return false } var errTooManyOpenStreams = errors.New("too many open streams") type streamsMap struct { + ctx context.Context // not used for cancellations, but carries the values associated with the connection perspective protocol.Perspective maxIncomingBidiStreams uint64 @@ -64,6 +65,7 @@ type streamsMap struct { var _ streamManager = &streamsMap{} func newStreamsMap( + ctx context.Context, sender streamSender, newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, maxIncomingBidiStreams uint64, @@ -71,6 +73,7 @@ func newStreamsMap( perspective protocol.Perspective, ) streamManager { m := &streamsMap{ + ctx: ctx, perspective: perspective, newFlowController: newFlowController, maxIncomingBidiStreams: maxIncomingBidiStreams, @@ -86,7 +89,7 @@ func (m *streamsMap) initMaps() { protocol.StreamTypeBidi, func(num protocol.StreamNum) streamI { id := num.StreamID(protocol.StreamTypeBidi, m.perspective) - return newStream(id, m.sender, m.newFlowController(id)) + return newStream(m.ctx, id, m.sender, m.newFlowController(id)) }, m.sender.queueControlFrame, ) @@ -94,7 +97,7 @@ func (m *streamsMap) initMaps() { protocol.StreamTypeBidi, func(num protocol.StreamNum) streamI { id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) - return newStream(id, m.sender, m.newFlowController(id)) + return newStream(m.ctx, id, m.sender, m.newFlowController(id)) }, m.maxIncomingBidiStreams, m.sender.queueControlFrame, @@ -103,7 +106,7 @@ func (m *streamsMap) initMaps() { protocol.StreamTypeUni, func(num protocol.StreamNum) sendStreamI { id := num.StreamID(protocol.StreamTypeUni, m.perspective) - return newSendStream(id, m.sender, m.newFlowController(id)) + return newSendStream(m.ctx, id, m.sender, m.newFlowController(id)) }, m.sender.queueControlFrame, ) diff --git a/streams_map_test.go b/streams_map_test.go index 77ee4aa82..b4a2d91ff 100644 --- a/streams_map_test.go +++ b/streams_map_test.go @@ -87,7 +87,7 @@ var _ = Describe("Streams Map", func() { BeforeEach(func() { mockSender = NewMockStreamSender(mockCtrl) - m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap) + m = newStreamsMap(context.Background(), mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap) }) Context("opening", func() {