expose the connection tracing ID on the stream context (#4414)

This is especially interesting for HTTP servers: They can now learn
which connection a request was received on.
This commit is contained in:
Marten Seemann
2024-04-07 02:41:25 +12:00
committed by GitHub
parent 183d42a729
commit e310b80cf3
8 changed files with 26 additions and 11 deletions

View File

@@ -272,8 +272,8 @@ var newConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
getMaxPacketSize(s.conn.RemoteAddr()),
@@ -381,8 +381,8 @@ var newClientConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.preSetup()
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
@@ -471,6 +471,7 @@ func (s *connection) preSetup() {
)
s.earlyConnReadyChan = make(chan struct{})
s.streamsMap = newStreamsMap(
s.ctx,
s,
s.newFlowController,
uint64(s.config.MaxIncomingStreams),

View File

@@ -536,6 +536,7 @@ var _ = Describe("HTTP tests", func() {
It("sets conn context", func() {
type ctxKey int
var tracingID quic.ConnectionTracingID
server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context {
serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server)
Expect(ok).To(BeTrue())
@@ -543,6 +544,7 @@ var _ = Describe("HTTP tests", func() {
ctx = context.WithValue(ctx, ctxKey(0), "Hello")
ctx = context.WithValue(ctx, ctxKey(1), c)
tracingID = c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
return ctx
}
mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) {
@@ -558,6 +560,10 @@ var _ = Describe("HTTP tests", func() {
serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server)
Expect(ok).To(BeTrue())
Expect(serv).To(Equal(server))
id, ok := r.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
Expect(ok).To(BeTrue())
Expect(id).To(Equal(tracingID))
})
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port))

View File

@@ -60,6 +60,7 @@ var (
)
func newSendStream(
ctx context.Context,
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
@@ -71,7 +72,7 @@ func newSendStream(
writeChan: make(chan struct{}, 1),
writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
}
s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
return s
}

View File

@@ -35,7 +35,7 @@ var _ = Describe("Send Stream", func() {
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newSendStream(streamID, mockSender, mockFC)
str = newSendStream(context.Background(), streamID, mockSender, mockFC)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = gbytes.TimeoutWriter(str, timeout)

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"net"
"os"
"sync"
@@ -85,7 +86,9 @@ type stream struct {
var _ Stream = &stream{}
// newStream creates a new Stream
func newStream(streamID protocol.StreamID,
func newStream(
ctx context.Context,
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
) *stream {
@@ -99,7 +102,7 @@ func newStream(streamID protocol.StreamID,
s.completedMutex.Unlock()
},
}
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController)
s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController)
senderForReceiveStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"errors"
"io"
"os"
@@ -41,7 +42,7 @@ var _ = Describe("Stream", func() {
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
str = newStream(streamID, mockSender, mockFC)
str = newStream(context.Background(), streamID, mockSender, mockFC)
timeout := scaleDuration(250 * time.Millisecond)
strWithTimeout = struct {

View File

@@ -45,6 +45,7 @@ func (streamOpenErr) Timeout() bool { return false }
var errTooManyOpenStreams = errors.New("too many open streams")
type streamsMap struct {
ctx context.Context // not used for cancellations, but carries the values associated with the connection
perspective protocol.Perspective
maxIncomingBidiStreams uint64
@@ -64,6 +65,7 @@ type streamsMap struct {
var _ streamManager = &streamsMap{}
func newStreamsMap(
ctx context.Context,
sender streamSender,
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
maxIncomingBidiStreams uint64,
@@ -71,6 +73,7 @@ func newStreamsMap(
perspective protocol.Perspective,
) streamManager {
m := &streamsMap{
ctx: ctx,
perspective: perspective,
newFlowController: newFlowController,
maxIncomingBidiStreams: maxIncomingBidiStreams,
@@ -86,7 +89,7 @@ func (m *streamsMap) initMaps() {
protocol.StreamTypeBidi,
func(num protocol.StreamNum) streamI {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
return newStream(id, m.sender, m.newFlowController(id))
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
)
@@ -94,7 +97,7 @@ func (m *streamsMap) initMaps() {
protocol.StreamTypeBidi,
func(num protocol.StreamNum) streamI {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
return newStream(id, m.sender, m.newFlowController(id))
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.maxIncomingBidiStreams,
m.sender.queueControlFrame,
@@ -103,7 +106,7 @@ func (m *streamsMap) initMaps() {
protocol.StreamTypeUni,
func(num protocol.StreamNum) sendStreamI {
id := num.StreamID(protocol.StreamTypeUni, m.perspective)
return newSendStream(id, m.sender, m.newFlowController(id))
return newSendStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
)

View File

@@ -87,7 +87,7 @@ var _ = Describe("Streams Map", func() {
BeforeEach(func() {
mockSender = NewMockStreamSender(mockCtrl)
m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap)
m = newStreamsMap(context.Background(), mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap)
})
Context("opening", func() {