add a callback to client that is called after the version is negotiated

This commit is contained in:
Marten Seemann
2016-12-14 18:35:33 +07:00
parent 2377b3a111
commit dc05de3312
3 changed files with 37 additions and 10 deletions

View File

@@ -13,11 +13,20 @@ import (
)
var _ = Describe("Client", func() {
var client *Client
var session *mockSession
var (
client *Client
session *mockSession
versionNegotiateCallbackCalled bool
)
BeforeEach(func() {
client = &Client{}
versionNegotiateCallbackCalled = false
client = &Client{
versionNegotiateCallback: func() error {
versionNegotiateCallbackCalled = true
return nil
},
}
session = &mockSession{connectionID: 0x1337}
client.connectionID = 0x1337
client.session = session
@@ -162,6 +171,7 @@ var _ = Describe("Client", func() {
err = client.handlePacket(b.Bytes())
Expect(err).ToNot(HaveOccurred())
Expect(client.versionNegotiated).To(BeTrue())
Expect(versionNegotiateCallbackCalled).To(BeTrue())
})
It("changes the version after receiving a version negotiation packet", func() {
@@ -172,6 +182,7 @@ var _ = Describe("Client", func() {
err := client.handlePacket(getVersionNegotiation([]protocol.VersionNumber{newVersion}))
Expect(client.version).To(Equal(newVersion))
Expect(client.versionNegotiated).To(BeTrue())
Expect(versionNegotiateCallbackCalled).To(BeTrue())
// it swapped the sessions
Expect(client.session).ToNot(Equal(session))
Expect(err).ToNot(HaveOccurred())
@@ -195,6 +206,7 @@ var _ = Describe("Client", func() {
Expect(err).ToNot(HaveOccurred())
Expect(client.versionNegotiated).To(BeTrue())
Expect(session.packetCount).To(BeZero())
Expect(versionNegotiateCallbackCalled).To(BeFalse())
})
It("errors if the server should have accepted the offered version", func() {