From 7a97f34fac68534437534aa8df3324e0ec8c77dc Mon Sep 17 00:00:00 2001 From: Lucas Clemente Date: Wed, 18 May 2016 18:29:42 +0200 Subject: [PATCH] =?UTF-8?q?don't=20panic=20=F0=9F=A4=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fixes #93 --- ackhandler/sent_packet_handler.go | 2 +- crypto/cert_compression.go | 3 ++- crypto/curve_25519.go | 6 +++--- crypto/curve_25519_test.go | 6 ++++-- frames/ack_frame.go | 2 +- h2quic/response_writer.go | 2 +- handshake/crypto_setup.go | 13 ++++++++----- handshake/crypto_setup_test.go | 9 ++++++--- handshake/server_config.go | 6 +++--- handshake/server_config_test.go | 7 +++++-- packet_packer.go | 7 ++++--- server.go | 18 ++++++++++++++---- server_test.go | 4 ++-- session.go | 14 +++++++++----- session_test.go | 11 ++++++++--- 15 files changed, 71 insertions(+), 39 deletions(-) diff --git a/ackhandler/sent_packet_handler.go b/ackhandler/sent_packet_handler.go index 9f0ef9b5..663d3d0e 100644 --- a/ackhandler/sent_packet_handler.go +++ b/ackhandler/sent_packet_handler.go @@ -103,7 +103,7 @@ func (h *sentPacketHandler) SentPacket(packet *Packet) error { } packet.sendTime = time.Now() if packet.Length == 0 { - panic("SentPacketHandler: packet cannot be empty") + return errors.New("SentPacketHandler: packet cannot be empty") } h.bytesInFlight += packet.Length diff --git a/crypto/cert_compression.go b/crypto/cert_compression.go index a6dc8f81..8fdb2579 100644 --- a/crypto/cert_compression.go +++ b/crypto/cert_compression.go @@ -6,6 +6,7 @@ import ( "compress/zlib" "encoding/binary" "errors" + "fmt" "hash/fnv" "github.com/lucas-clemente/quic-go/utils" @@ -63,7 +64,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by if totalUncompressedLen > 0 { gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain)) if err != nil { - panic(err) + return nil, fmt.Errorf("cert compression failed: %s", err.Error()) } utils.WriteUint32(res, uint32(totalUncompressedLen)) diff --git a/crypto/curve_25519.go b/crypto/curve_25519.go index 349fa790..4813a261 100644 --- a/crypto/curve_25519.go +++ b/crypto/curve_25519.go @@ -17,17 +17,17 @@ type curve25519KEX struct { var _ KeyExchange = &curve25519KEX{} // NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html -func NewCurve25519KEX() KeyExchange { +func NewCurve25519KEX() (KeyExchange, error) { c := &curve25519KEX{} if _, err := io.ReadFull(rand.Reader, c.secret[:]); err != nil { - panic("Curve25519: could not create private key") + return nil, errors.New("Curve25519: could not create private key") } // See https://cr.yp.to/ecdh.html c.secret[0] &= 248 c.secret[31] &= 127 c.secret[31] |= 64 curve25519.ScalarBaseMult(&c.public, &c.secret) - return c + return c, nil } func (c *curve25519KEX) PublicKey() []byte { diff --git a/crypto/curve_25519_test.go b/crypto/curve_25519_test.go index 04736cd2..f1455c8f 100644 --- a/crypto/curve_25519_test.go +++ b/crypto/curve_25519_test.go @@ -7,8 +7,10 @@ import ( var _ = Describe("ProofRsa", func() { It("works", func() { - a := NewCurve25519KEX() - b := NewCurve25519KEX() + a, err := NewCurve25519KEX() + Expect(err).ToNot(HaveOccurred()) + b, err := NewCurve25519KEX() + Expect(err).ToNot(HaveOccurred()) sA, err := a.CalculateSharedKey(b.PublicKey()) Expect(err).ToNot(HaveOccurred()) sB, err := b.CalculateSharedKey(a.PublicKey()) diff --git a/frames/ack_frame.go b/frames/ack_frame.go index 7b12ced0..785585c8 100644 --- a/frames/ack_frame.go +++ b/frames/ack_frame.go @@ -102,7 +102,7 @@ func (f *AckFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error } if rangeCounter != uint8(numRanges) { - panic("Inconsistent number of NACK ranges written.") + return errors.New("BUG: Inconsistent number of NACK ranges written") } // TODO: Remove once we drop support for <32 diff --git a/h2quic/response_writer.go b/h2quic/response_writer.go index 2b7800bb..fc0d8a45 100644 --- a/h2quic/response_writer.go +++ b/h2quic/response_writer.go @@ -52,7 +52,7 @@ func (w *responseWriter) WriteHeader(status int) { BlockFragment: headers.Bytes(), }) if err != nil { - panic(err) + utils.Errorf("could not write h2 header: %s", err.Error()) } } diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index b6f1e425..9a201cc2 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -16,7 +16,7 @@ import ( type KeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte) (crypto.AEAD, error) // KeyExchangeFunction is used to make a new KEX -type KeyExchangeFunction func() crypto.KeyExchange +type KeyExchangeFunction func() (crypto.KeyExchange, error) // The CryptoSetup handles all things crypto for the Session type CryptoSetup struct { @@ -44,10 +44,10 @@ type CryptoSetup struct { var _ crypto.AEAD = &CryptoSetup{} // NewCryptoSetup creates a new CryptoSetup instance -func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream, connectionParametersManager *ConnectionParametersManager, aeadChanged chan struct{}) *CryptoSetup { +func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber, scfg *ServerConfig, cryptoStream utils.Stream, connectionParametersManager *ConnectionParametersManager, aeadChanged chan struct{}) (*CryptoSetup, error) { nonce := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - panic(err) + return nil, err } return &CryptoSetup{ connID: connID, @@ -59,7 +59,7 @@ func NewCryptoSetup(connID protocol.ConnectionID, version protocol.VersionNumber cryptoStream: cryptoStream, connectionParametersManager: connectionParametersManager, aeadChanged: aeadChanged, - } + }, nil } // HandleCryptoStream reads and writes messages on the crypto stream @@ -219,7 +219,10 @@ func (h *CryptoSetup) handleCHLO(sni string, data []byte, cryptoData map[Tag][]b } // Generate a new curve instance to derive the forward secure key - ephermalKex := h.keyExchange() + ephermalKex, err := h.keyExchange() + if err != nil { + return nil, err + } ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS]) if err != nil { return nil, err diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 6ae0ebd8..b8a3ec59 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -104,16 +104,19 @@ var _ = Describe("Crypto setup", func() { ) BeforeEach(func() { + var err error aeadChanged = make(chan struct{}, 1) stream = &mockStream{} kex = &mockKEX{} signer = &mockSigner{} - scfg = NewServerConfig(kex, signer) + scfg, err = NewServerConfig(kex, signer) + Expect(err).NotTo(HaveOccurred()) v := protocol.SupportedVersions[len(protocol.SupportedVersions)-1] cpm = NewConnectionParamatersManager() - cs = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream, cpm, aeadChanged) + cs, err = NewCryptoSetup(protocol.ConnectionID(42), v, scfg, stream, cpm, aeadChanged) + Expect(err).NotTo(HaveOccurred()) cs.keyDerivation = mockKeyDerivation - cs.keyExchange = func() crypto.KeyExchange { return &mockKEX{ephermal: true} } + cs.keyExchange = func() (crypto.KeyExchange, error) { return &mockKEX{ephermal: true}, nil } }) It("has a nonce", func() { diff --git a/handshake/server_config.go b/handshake/server_config.go index 2bbd0773..38623a51 100644 --- a/handshake/server_config.go +++ b/handshake/server_config.go @@ -16,17 +16,17 @@ type ServerConfig struct { } // NewServerConfig creates a new server config -func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) *ServerConfig { +func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfig, error) { id := make([]byte, 16) _, err := io.ReadFull(rand.Reader, id) if err != nil { - panic(err) + return nil, err } return &ServerConfig{ kex: kex, signer: signer, ID: id, - } + }, nil } // Get the server config binary representation diff --git a/handshake/server_config_test.go b/handshake/server_config_test.go index bce9f1c0..0e4c100e 100644 --- a/handshake/server_config_test.go +++ b/handshake/server_config_test.go @@ -16,8 +16,11 @@ var _ = Describe("ServerConfig", func() { ) BeforeEach(func() { - kex = crypto.NewCurve25519KEX() - scfg = NewServerConfig(kex, nil) + var err error + kex, err = crypto.NewCurve25519KEX() + Expect(err).NotTo(HaveOccurred()) + scfg, err = NewServerConfig(kex, nil) + Expect(err).NotTo(HaveOccurred()) }) It("gets the proper binary representation", func() { diff --git a/packet_packer.go b/packet_packer.go index acd3e724..c3e1065f 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "errors" "sync/atomic" "github.com/lucas-clemente/quic-go/ackhandler" @@ -101,7 +102,7 @@ func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, con raw.Write(ciphertext) if protocol.ByteCount(raw.Len()) > protocol.MaxPacketSize { - panic("internal inconsistency: packet too large") + return nil, errors.New("PacketPacker BUG: packet too large") } return &packedPacket{ @@ -148,7 +149,7 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra } if payloadLength > maxFrameSize { - panic("internal inconsistency: packet payload too large") + return nil, errors.New("PacketPacker BUG: packet payload too large") } if !includeStreamFrames { @@ -167,7 +168,7 @@ func (p *packetPacker) composeNextPacket(stopWaitingFrame *frames.StopWaitingFra frame.DataLenPresent = true // set the dataLen by default. Remove them later if applicable if payloadLength > maxFrameSize { - panic("internal inconsistency: packet payload too large") + return nil, errors.New("PacketPacker BUG: packet payload too large") } // Does the frame fit into the remaining space? diff --git a/server.go b/server.go index 74f9689d..89e9b8bc 100644 --- a/server.go +++ b/server.go @@ -32,7 +32,7 @@ type Server struct { streamCallback StreamCallback - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) packetHandler + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) } // NewServer makes a new server @@ -42,7 +42,14 @@ func NewServer(tlsConfig *tls.Config, cb StreamCallback) (*Server, error) { return nil, err } - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) + kex, err := crypto.NewCurve25519KEX() + if err != nil { + return nil, err + } + scfg, err := handshake.NewServerConfig(kex, signer) + if err != nil { + return nil, err + } return &Server{ signer: signer, @@ -123,7 +130,7 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet if !ok { utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, hdr.VersionNumber, remoteAddr) - session = s.newSession( + session, err = s.newSession( &udpConn{conn: conn, currentAddr: remoteAddr}, hdr.VersionNumber, hdr.ConnectionID, @@ -131,6 +138,9 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet s.streamCallback, s.closeCallback, ) + if err != nil { + return err + } go session.run() s.sessionsMutex.Lock() s.sessions[hdr.ConnectionID] = session @@ -159,7 +169,7 @@ func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte { } err := responsePublicHeader.WritePublicHeader(fullReply) if err != nil { - panic(err) // Should not happen ;) + utils.Errorf("error composing version negotiation packet: %s", err.Error()) } fullReply.Write(protocol.SupportedVersionsAsTags) return fullReply.Bytes() diff --git a/server_test.go b/server_test.go index fbade42d..df5fb329 100644 --- a/server_test.go +++ b/server_test.go @@ -25,10 +25,10 @@ func (s *mockSession) handlePacket(addr interface{}, hdr *publicHeader, data []b func (s *mockSession) run() { } -func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) packetHandler { +func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) { return &mockSession{ connectionID: connectionID, - } + }, nil } var _ = Describe("Server", func() { diff --git a/session.go b/session.go index 658479e9..30eba1c8 100644 --- a/session.go +++ b/session.go @@ -82,7 +82,7 @@ type Session struct { } // newSession makes a new session -func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) packetHandler { +func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, streamCallback StreamCallback, closeCallback closeCallback) (packetHandler, error) { stopWaitingManager := ackhandler.NewStopWaitingManager() connectionParametersManager := handshake.NewConnectionParamatersManager() @@ -107,7 +107,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol } cryptoStream, _ := session.OpenStream(1) - session.cryptoSetup = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) + var err error + session.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, v, sCfg, cryptoStream, session.connectionParametersManager, session.aeadChanged) + if err != nil { + return nil, err + } session.packer = &packetPacker{ aead: session.cryptoSetup, @@ -126,7 +130,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol protocol.DefaultMaxCongestionWindow, ) - return session + return session, err } // run the session main loop @@ -238,7 +242,7 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *publicHeader, da case *frames.PingFrame: utils.Debugf("\t<- %#v", frame) default: - panic("unexpected frame type") + return errors.New("Session BUG: unexpected frame type") } if err != nil { return err @@ -549,7 +553,7 @@ func (s *Session) sendConnectionClose(quicErr *qerr.QuicError) error { return err } if packet == nil { - panic("Session: internal inconsistency: expected packet not to be nil") + return errors.New("Session BUG: expected packet not to be nil") } return s.conn.write(packet.raw) } diff --git a/session_test.go b/session_test.go index 505fb43d..6e3b3ad6 100644 --- a/session_test.go +++ b/session_test.go @@ -94,15 +94,20 @@ var _ = Describe("Session", func() { signer, err := crypto.NewRSASigner(testdata.GetTLSConfig()) Expect(err).ToNot(HaveOccurred()) - scfg := handshake.NewServerConfig(crypto.NewCurve25519KEX(), signer) - session = newSession( + kex, err := crypto.NewCurve25519KEX() + Expect(err).NotTo(HaveOccurred()) + scfg, err := handshake.NewServerConfig(kex, signer) + Expect(err).NotTo(HaveOccurred()) + pSession, err := newSession( conn, 0, 0, scfg, func(*Session, utils.Stream) { streamCallbackCalled = true }, func(protocol.ConnectionID) { closeCallbackCalled = true }, - ).(*Session) + ) + Expect(err).NotTo(HaveOccurred()) + session = pSession.(*Session) Expect(session.streams).To(HaveLen(1)) // Crypto stream })