diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index a28c80e4d..78856ac49 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -3,7 +3,7 @@ package handshake import ( "bytes" "crypto/rand" - "fmt" + "errors" "io" "sync" @@ -52,19 +52,15 @@ func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber } // HandleCryptoStream reads and writes messages on the crypto stream -func (h *CryptoSetup) HandleCryptoStream() { - // TODO: Fix error handling - +func (h *CryptoSetup) HandleCryptoStream() error { for { cachingReader := utils.NewCachingReader(h.cryptoStream) messageTag, cryptoData, err := ParseHandshakeMessage(cachingReader) if err != nil { - fmt.Printf("error in crypto stream (TODO: handle): %s\n", err.Error()) - return + return err } if messageTag != TagCHLO { - fmt.Printf("error in crypto stream (TODO: handle): %s\n", "Session: expected CHLO") - return + return errors.New("CryptoSetup: expected CHLO") } chloData := cachingReader.Get() @@ -73,27 +69,23 @@ func (h *CryptoSetup) HandleCryptoStream() { // We have a CHLO with a proper server config ID, do a 0-RTT handshake reply, err = h.handleCHLO(chloData, cryptoData) if err != nil { - fmt.Printf("error in crypto stream (TODO: handle): %s\n", err.Error()) - return + return err } _, err = h.cryptoStream.Write(reply) if err != nil { - fmt.Printf("error in crypto stream (TODO: handle): %s\n", err.Error()) - return + return err } - return + return nil } // We have an inchoate or non-matching CHLO, we now send a rejection reply, err = h.handleInchoateCHLO(chloData) if err != nil { - fmt.Printf("error in crypto stream (TODO: handle): %s\n", err.Error()) - return + return err } _, err = h.cryptoStream.Write(reply) if err != nil { - fmt.Printf("error in crypto stream (TODO: handle): %s\n", err.Error()) - return + return err } } } diff --git a/session.go b/session.go index bb810cd39..50b6900f5 100644 --- a/session.go +++ b/session.go @@ -45,6 +45,7 @@ type Session struct { receivedPackets chan receivedPacket closeChan chan struct{} + closed bool // Used to calculate the next packet number from the truncated wire representation lastRcvdPacketNumber protocol.PacketNumber @@ -64,7 +65,12 @@ func NewSession(conn connection, v protocol.VersionNumber, connectionID protocol cryptoStream, _ := session.NewStream(1) cryptoSetup := handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream) - go cryptoSetup.HandleCryptoStream() + + go func() { + if err := cryptoSetup.HandleCryptoStream(); err != nil { + session.Close(err) + } + }() session.packer = &packetPacker{aead: cryptoSetup, connectionID: connectionID} session.unpacker = &packetUnpacker{aead: cryptoSetup} @@ -196,6 +202,10 @@ func (s *Session) handleRstStreamFrame(frame *frames.RstStreamFrame) error { // Close the connection by sending a ConnectionClose frame func (s *Session) Close(e error) error { + if s.closed { + return nil + } + s.closed = true s.closeChan <- struct{}{} if e == nil { e = protocol.NewQuicError(errorcodes.QUIC_PEER_GOING_AWAY, "peer going away") diff --git a/session_test.go b/session_test.go index 6dda633fc..b108a15b7 100644 --- a/session_test.go +++ b/session_test.go @@ -244,4 +244,23 @@ var _ = Describe("Session", func() { Expect(conn.written[0]).To(ContainSubstring(string("foobar"))) }) }) + + It("closes when crypto stream errors", func() { + path := os.Getenv("GOPATH") + "/src/github.com/lucas-clemente/quic-go/example/" + signer, err := crypto.NewRSASigner(path+"cert.der", path+"key.der") + Expect(err).ToNot(HaveOccurred()) + scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) + session = NewSession(conn, 0, 0, scfg, nil).(*Session) + s, err := session.NewStream(3) + Expect(err).NotTo(HaveOccurred()) + err = session.handleStreamFrame(&frames.StreamFrame{ + StreamID: 1, + Data: []byte("4242\x00\x00\x00\x00"), + }) + Expect(err).NotTo(HaveOccurred()) + time.Sleep(time.Millisecond) + Expect(session.closed).To(BeTrue()) + _, err = s.Write([]byte{}) + Expect(err).To(MatchError("CryptoSetup: expected CHLO")) + }) })