diff --git a/session_test.go b/session_test.go index 227be291b..6dda633fc 100644 --- a/session_test.go +++ b/session_test.go @@ -17,20 +17,29 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) -type mockConnection struct{} +type mockConnection struct { + written [][]byte +} + +func (m *mockConnection) write(p []byte) error { + m.written = append(m.written, p) + return nil +} -func (*mockConnection) write(p []byte) error { return nil } func (*mockConnection) setCurrentRemoteAddr(addr interface{}) {} var _ = Describe("Session", func() { var ( session *Session callbackCalled bool + conn *mockConnection ) BeforeEach(func() { + conn = &mockConnection{} callbackCalled = false session = &Session{ + conn: conn, streams: make(map[protocol.StreamID]*stream), streamCallback: func(*Session, utils.Stream) { callbackCalled = true }, } @@ -178,7 +187,7 @@ var _ = Describe("Session", func() { signer, err := crypto.NewRSASigner(path+"cert.der", path+"key.der") Expect(err).ToNot(HaveOccurred()) scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = NewSession(&mockConnection{}, 0, 0, scfg, nil).(*Session) + session = NewSession(conn, 0, 0, scfg, nil).(*Session) go session.Run() Expect(runtime.NumGoroutine()).To(Equal(nGoRoutinesBefore + 2)) }) @@ -204,4 +213,35 @@ var _ = Describe("Session", func() { Expect(err).To(Equal(testErr)) }) }) + + Context("sending packets", func() { + BeforeEach(func() { + path := os.Getenv("GOPATH") + "/src/github.com/lucas-clemente/quic-go/example/" + signer, err := crypto.NewRSASigner(path+"cert.der", path+"key.der") + Expect(err).ToNot(HaveOccurred()) + scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) + session = NewSession(conn, 0, 0, scfg, nil).(*Session) + }) + + It("sends ack frames", func() { + session.receivedPacketHandler.ReceivedPacket(1, true) + err := session.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(conn.written).To(HaveLen(1)) + Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x4c, 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0}))) + }) + + It("sends queued stream frames", func() { + session.QueueStreamFrame(&frames.StreamFrame{ + StreamID: 1, + Data: []byte("foobar"), + }) + session.receivedPacketHandler.ReceivedPacket(1, true) + err := session.sendPacket() + Expect(err).NotTo(HaveOccurred()) + Expect(conn.written).To(HaveLen(1)) + Expect(conn.written[0]).To(ContainSubstring(string([]byte{0x4c, 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0}))) + Expect(conn.written[0]).To(ContainSubstring(string("foobar"))) + }) + }) })