update mint

This commit is contained in:
Marten Seemann
2018-08-12 15:26:40 +07:00
parent 40050f558d
commit e34256beed
13 changed files with 883 additions and 288 deletions

View File

@@ -18,6 +18,12 @@ library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns
with earlier TLS versions. However, unnecessary parts will be ruthlessly cut
off.
## DTLS Support
Mint has partial support for DTLS, but that support is not yet complete
and may still contain serious defects.
## Quickstart
Installation is the same as for any other Go package:

View File

@@ -58,7 +58,7 @@ type clientStateStart struct {
cookie []byte
firstClientHello *HandshakeMessage
helloRetryRequest *HandshakeMessage
hsCtx HandshakeContext
hsCtx *HandshakeContext
}
var _ HandshakeState = &clientStateStart{}
@@ -172,8 +172,10 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
}
ch.CipherSuites = compatibleSuites
// TODO(ekr@rtfm.com): Check that the ticket can be used for early
// data.
// Signal early data if we're going to do it
if len(state.Opts.EarlyData) > 0 {
if state.Config.AllowEarlyData && state.helloRetryRequest == nil {
state.Params.ClientSendingEarlyData = true
ed = &EarlyDataExtension{}
err = ch.Extensions.Add(ed)
@@ -255,9 +257,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash)
logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret)
clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret)
} else if len(state.Opts.EarlyData) > 0 {
logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK")
return nil, nil, AlertInternalError
} else {
clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch)
if err != nil {
@@ -291,7 +290,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
if state.Params.ClientSendingEarlyData {
toSend = append(toSend, []HandshakeAction{
RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys},
SendEarlyData{},
}...)
}
@@ -302,7 +300,7 @@ type clientStateWaitSH struct {
Config *Config
Opts ConnectionOptions
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
OfferedDH map[NamedGroup][]byte
OfferedPSK PreSharedKey
PSK []byte
@@ -412,6 +410,11 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
body: h.Sum(nil),
}
state.hsCtx.receivedEndOfFlight()
// TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT
// mode. In DTLS, we also need to bump the sequence number.
// This is a pre-existing defect in Mint. Issue #175.
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]")
return clientStateStart{
Config: state.Config,
@@ -420,7 +423,7 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
cookie: serverCookie.Cookie,
firstClientHello: firstClientHello,
helloRetryRequest: hm,
}, nil, AlertNoAlert
}, []HandshakeAction{ResetOut{1}}, AlertNoAlert
}
// This is SH.
@@ -515,7 +518,6 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret)
serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret)
logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]")
nextState := clientStateWaitEE{
Config: state.Config,
@@ -530,13 +532,20 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState,
toSend := []HandshakeAction{
RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys},
}
// We're definitely not going to have to send anything with
// early data.
if !state.Params.ClientSendingEarlyData {
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)})
}
return nextState, toSend, AlertNoAlert
}
type clientStateWaitEE struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
masterSecret []byte
@@ -596,6 +605,14 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
state.handshakeHash.Write(hm.Marshal())
toSend := []HandshakeAction{}
if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData {
// We didn't get 0-RTT, so rekey to handshake.
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
}
if state.Params.UsingPSK {
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]")
nextState := clientStateWaitFinished{
@@ -608,7 +625,7 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
return nextState, toSend, AlertNoAlert
}
logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]")
@@ -622,13 +639,13 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState,
clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret,
serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret,
}
return nextState, nil, AlertNoAlert
return nextState, toSend, AlertNoAlert
}
type clientStateWaitCertCR struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
masterSecret []byte
@@ -706,7 +723,7 @@ func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta
type clientStateWaitCert struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
@@ -760,7 +777,7 @@ func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
type clientStateWaitCV struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
@@ -861,7 +878,7 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
type clientStateWaitFinished struct {
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
handshakeHash hash.Hash
@@ -933,6 +950,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
toSend := []HandshakeAction{}
if state.Params.UsingEarlyData {
logf(logTypeHandshake, "Sending end of early data")
// Note: We only send EOED if the server is actually going to use the early
// data. Otherwise, it will never see it, and the transcripts will
// mismatch.
@@ -942,10 +960,11 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
state.handshakeHash.Write(eoedm.Marshal())
logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal())
}
clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys})
// And then rekey to handshake
toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData,
KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)})
}
if state.Params.UsingClientAuth {
// Extract constraints from certicateRequest
@@ -1045,6 +1064,8 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys},
}...)
state.hsCtx.receivedEndOfFlight()
logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]")
nextState := stateConnected{
Params: state.Params,

View File

@@ -25,6 +25,7 @@ const (
RecordTypeAlert RecordType = 21
RecordTypeHandshake RecordType = 22
RecordTypeApplicationData RecordType = 23
RecordTypeAck RecordType = 25
)
// enum {...} HandshakeType;
@@ -166,6 +167,8 @@ const (
type State uint8
const (
StateInit = 0
// states valid for the client
StateClientStart State = iota
StateClientWaitSH
@@ -179,6 +182,7 @@ const (
StateServerStart State = iota
StateServerRecvdCH
StateServerNegotiated
StateServerReadPastEarlyData
StateServerWaitEOED
StateServerWaitFlight2
StateServerWaitCert
@@ -211,6 +215,8 @@ func (s State) String() string {
return "Server RECVD_CH"
case StateServerNegotiated:
return "Server NEGOTIATED"
case StateServerReadPastEarlyData:
return "Server READ_PAST_EARLY_DATA"
case StateServerWaitEOED:
return "Server WAIT_EOED"
case StateServerWaitFlight2:
@@ -252,3 +258,9 @@ func (e Epoch) label() string {
}
return "Application data (updated)"
}
func assert(b bool) {
if !b {
panic("Assertion failed")
}
}

View File

@@ -13,8 +13,6 @@ import (
"time"
)
var WouldBlock = fmt.Errorf("Would have blocked")
type Certificate struct {
Chain []*x509.Certificate
PrivateKey crypto.Signer
@@ -253,6 +251,8 @@ type ConnectionState struct {
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates
NextProto string // Selected ALPN proto
UsingPSK bool // Are we using PSK.
UsingEarlyData bool // Did we negotiate 0-RTT.
}
// Conn implements the net.Conn interface, as with "crypto/tls"
@@ -263,8 +263,6 @@ type Conn struct {
conn net.Conn
isClient bool
EarlyData []byte
state stateConnected
hState HandshakeState
handshakeMutex sync.Mutex
@@ -273,22 +271,27 @@ type Conn struct {
readBuffer []byte
in, out *RecordLayer
hsCtx HandshakeContext
hsCtx *HandshakeContext
}
func NewConn(conn net.Conn, config *Config, isClient bool) *Conn {
c := &Conn{conn: conn, config: config, isClient: isClient}
c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}}
if !config.UseDTLS {
c.in = NewRecordLayerTLS(c.conn)
c.out = NewRecordLayerTLS(c.conn)
c.hsCtx.hIn = NewHandshakeLayerTLS(c.in)
c.hsCtx.hOut = NewHandshakeLayerTLS(c.out)
c.in = NewRecordLayerTLS(c.conn, directionRead)
c.out = NewRecordLayerTLS(c.conn, directionWrite)
c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in)
c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out)
} else {
c.in = NewRecordLayerDTLS(c.conn)
c.out = NewRecordLayerDTLS(c.conn)
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.in)
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.out)
c.in = NewRecordLayerDTLS(c.conn, directionRead)
c.out = NewRecordLayerDTLS(c.conn, directionWrite)
c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in)
c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out)
c.hsCtx.timeoutMS = initialTimeout
c.hsCtx.timers = newTimerSet()
c.hsCtx.waitingNextFlight = true
}
c.in.label = c.label()
c.out.label = c.label()
c.hsCtx.hIn.nonblocking = c.config.NonBlocking
return c
}
@@ -374,20 +377,54 @@ func (c *Conn) consumeRecord() error {
return io.EOF
}
case RecordTypeAck:
if !c.hsCtx.hIn.datagram {
logf(logTypeHandshake, "Received ACK in TLS mode")
return AlertUnexpectedMessage
}
return c.hsCtx.processAck(pt.fragment)
case RecordTypeApplicationData:
c.readBuffer = append(c.readBuffer, pt.fragment...)
logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer)
}
return err
}
func readPartial(in *[]byte, buffer []byte) int {
logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in)))
read := copy(buffer, *in)
*in = (*in)[read:]
logf(logTypeVerbose, "Returning %v", string(buffer))
return read
}
// Read application data up to the size of buffer. Handshake and alert records
// are consumed by the Conn object directly.
func (c *Conn) Read(buffer []byte) (int, error) {
if _, connected := c.hState.(stateConnected); !connected {
return 0, errors.New("Read called before the handshake completed")
// Clients can't call Read prior to handshake completion.
if c.isClient {
return 0, errors.New("Read called before the handshake completed")
}
// Neither can servers that don't allow early data.
if !c.config.AllowEarlyData {
return 0, errors.New("Read called before the handshake completed")
}
// If there's no early data, then return WouldBlock
if len(c.hsCtx.earlyData) == 0 {
return 0, AlertWouldBlock
}
return readPartial(&c.hsCtx.earlyData, buffer), nil
}
// The handshake is now connected.
logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer))
if alert := c.Handshake(); alert != AlertNoAlert {
return 0, alert
@@ -397,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) {
return 0, nil
}
// Run our timers.
if c.config.UseDTLS {
if err := c.hsCtx.timers.check(time.Now()); err != nil {
return 0, AlertInternalError
}
}
// Lock the input channel
c.in.Lock()
defer c.in.Unlock()
@@ -406,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) {
// err can be nil if consumeRecord processed a non app-data
// record.
if err != nil {
if c.config.NonBlocking || err != WouldBlock {
if c.config.NonBlocking || err != AlertWouldBlock {
logf(logTypeIO, "conn.Read returns err=%v", err)
return 0, err
}
}
}
var read int
n := len(buffer)
logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer))
if len(c.readBuffer) <= n {
buffer = buffer[:len(c.readBuffer)]
copy(buffer, c.readBuffer)
read = len(c.readBuffer)
c.readBuffer = c.readBuffer[:0]
} else {
logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n)
copy(buffer[:n], c.readBuffer[:n])
c.readBuffer = c.readBuffer[n:]
read = n
}
logf(logTypeVerbose, "Returning %v", string(buffer))
return read, nil
return readPartial(&c.readBuffer, buffer), nil
}
// Write application data
@@ -438,6 +466,10 @@ func (c *Conn) Write(buffer []byte) (int, error) {
c.out.Lock()
defer c.out.Unlock()
if !c.Writable() {
return 0, errors.New("Write called before the handshake completed (and early data not in use)")
}
// Send full-size fragments
var start int
sent := 0
@@ -549,13 +581,23 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
}
case SendQueuedHandshake:
err := c.hsCtx.hOut.SendQueuedMessages()
_, err := c.hsCtx.hOut.SendQueuedMessages()
if err != nil {
logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err)
return AlertInternalError
}
if c.config.UseDTLS {
c.hsCtx.timers.start(retransmitTimerLabel,
c.hsCtx.handshakeRetransmit,
c.hsCtx.timeoutMS)
}
case RekeyIn:
logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet)
// Check that we don't have an input data in the handshake frame parser.
if len(c.hsCtx.hIn.frame.remainder) > 0 {
logf(logTypeHandshake, "%s Rekey with data still in handshake buffers", label)
return AlertDecodeError
}
err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv)
if err != nil {
logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err)
@@ -570,61 +612,9 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
return AlertInternalError
}
case SendEarlyData:
logf(logTypeHandshake, "%s Sending early data...", label)
_, err := c.Write(c.EarlyData)
if err != nil {
logf(logTypeHandshake, "%s Error writing early data: %v", label, err)
return AlertInternalError
}
case ReadPastEarlyData:
logf(logTypeHandshake, "%s Reading past early data...", label)
// Scan past all records that fail to decrypt
_, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok := err.(DecryptError)
for ok {
_, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err == nil {
break
}
_, ok = err.(DecryptError)
}
case ReadEarlyData:
logf(logTypeHandshake, "%s Reading early data...", label)
t, err := c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type(1): %v", label, t)
for t == RecordTypeApplicationData {
// Read a record into the buffer. Note that this is safe
// in blocking mode because we read the record in in
// PeekRecordType.
pt, err := c.in.ReadRecord()
if err != nil {
logf(logTypeHandshake, "%s Error reading early data record: %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment)
c.EarlyData = append(c.EarlyData, pt.fragment...)
t, err = c.in.PeekRecordType(!c.config.NonBlocking)
if err != nil {
logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err)
return AlertInternalError
}
logf(logTypeHandshake, "%s Got record type (2): %v", label, t)
}
logf(logTypeHandshake, "%s Done reading early data", label)
case ResetOut:
logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq)
c.out.ResetClear(action.seq)
case StorePSK:
logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity)
@@ -637,7 +627,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert {
}
default:
logf(logTypeHandshake, "%s Unknown actionuction type", label)
logf(logTypeHandshake, "%s Unknown action type", label)
assert(false)
return AlertInternalError
}
@@ -657,7 +648,6 @@ func (c *Conn) HandshakeSetup() Alert {
opts := ConnectionOptions{
ServerName: c.config.ServerName,
NextProtos: c.config.NextProtos,
EarlyData: c.EarlyData,
}
if c.isClient {
@@ -706,18 +696,21 @@ type handshakeMessageReaderImpl struct {
var _ handshakeMessageReader = &handshakeMessageReaderImpl{}
func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) {
hm, err := r.hsCtx.hIn.ReadMessage()
if err == WouldBlock {
return nil, AlertWouldBlock
var hm *HandshakeMessage
var err error
for {
hm, err = r.hsCtx.hIn.ReadMessage()
if err == AlertWouldBlock {
return nil, AlertWouldBlock
}
if err != nil {
logf(logTypeHandshake, "Error reading message: %v", err)
return nil, AlertCloseNotify
}
if hm != nil {
break
}
}
if err != nil {
logf(logTypeHandshake, "[client] Error reading message: %v", err)
return nil, AlertCloseNotify
}
// Once you have read a message, you no longer need the outgoing queue
// for DTLS.
r.hsCtx.hOut.ClearQueuedMessages()
return hm, AlertNoAlert
}
@@ -753,14 +746,21 @@ func (c *Conn) Handshake() Alert {
state := c.hState
_, connected := state.(stateConnected)
hmr := &handshakeMessageReaderImpl{hsCtx: &c.hsCtx}
hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx}
for !connected {
var alert Alert
var actions []HandshakeAction
// Advance the state machine
state, actions, alert = state.Next(hmr)
if alert == WouldBlock {
if alert == AlertWouldBlock {
logf(logTypeHandshake, "%s Would block reading message: %s", label, alert)
// If we blocked, then run our timers to see if any have expired.
if c.hsCtx.hIn.datagram {
if err := c.hsCtx.timers.check(time.Now()); err != nil {
return AlertInternalError
}
}
return AlertWouldBlock
}
if alert == AlertCloseNotify {
@@ -788,6 +788,34 @@ func (c *Conn) Handshake() Alert {
if connected {
c.state = state.(stateConnected)
c.handshakeComplete = true
if !c.isClient {
// Send NewSessionTicket if configured to
if c.config.SendSessionTickets {
actions, alert := c.state.NewSessionTicket(
c.config.TicketLen,
c.config.TicketLifetime,
c.config.EarlyDataLifetime)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
}
// If there is early data, move it into the main buffer
if c.hsCtx.earlyData != nil {
c.readBuffer = c.hsCtx.earlyData
c.hsCtx.earlyData = nil
}
} else {
assert(c.hsCtx.earlyData == nil)
}
}
if c.config.NonBlocking {
@@ -798,23 +826,6 @@ func (c *Conn) Handshake() Alert {
}
}
// Send NewSessionTicket if acting as server
if !c.isClient && c.config.SendSessionTickets {
actions, alert := c.state.NewSessionTicket(
c.config.TicketLen,
c.config.TicketLifetime,
c.config.EarlyDataLifetime)
for _, action := range actions {
alert = c.takeAction(action)
if alert != AlertNoAlert {
logf(logTypeHandshake, "Error during handshake actions: %v", alert)
c.sendAlert(alert)
return alert
}
}
}
return AlertNoAlert
}
@@ -848,6 +859,9 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error {
}
func (c *Conn) GetHsState() State {
if c.hState == nil {
return StateInit
}
return c.hState.State()
}
@@ -878,7 +892,30 @@ func (c *Conn) ConnectionState() ConnectionState {
state.NextProto = c.state.Params.NextProto
state.VerifiedChains = c.state.verifiedChains
state.PeerCertificates = c.state.peerCertificates
state.UsingPSK = c.state.Params.UsingPSK
state.UsingEarlyData = c.state.Params.UsingEarlyData
}
return state
}
func (c *Conn) Writable() bool {
// If we're connected, we're writable.
if _, connected := c.hState.(stateConnected); connected {
return true
}
// If we're a client in 0-RTT, then we're writable.
if c.isClient && c.out.cipher.epoch == EpochEarlyData {
return true
}
return false
}
func (c *Conn) label() string {
if c.isClient {
return "client"
}
return "server"
}

View File

@@ -2,14 +2,33 @@ package mint
import (
"fmt"
"github.com/bifurcation/mint/syntax"
"time"
)
// This file is a placeholder. DTLS-specific stuff (timer management,
// ACKs, retransmits, etc. will eventually go here.
const (
initialMtu = 1200
initialMtu = 1200
initialTimeout = 100
)
// labels for timers
const (
retransmitTimerLabel = "handshake retransmit"
ackTimerLabel = "ack timer"
)
type SentHandshakeFragment struct {
seq uint32
offset int
fragLength int
record uint64
acked bool
}
type DtlsAck struct {
RecordNumbers []uint64 `tls:"head=2"`
}
func wireVersion(h *HandshakeLayer) uint16 {
if h.datagram {
return dtls12WireVersion
@@ -26,3 +45,178 @@ func dtlsConvertVersion(version uint16) uint16 {
}
panic(fmt.Sprintf("Internal error, unexpected version=%d", version))
}
// TODO(ekr@rtfm.com): Move these to state-machine.go
func (h *HandshakeContext) handshakeRetransmit() error {
if _, err := h.hOut.SendQueuedMessages(); err != nil {
return err
}
h.timers.start(retransmitTimerLabel,
h.handshakeRetransmit,
h.timeoutMS)
// TODO(ekr@rtfm.com): Back off timer
return nil
}
func (h *HandshakeContext) sendAck() error {
toack := h.hIn.recvdRecords
count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU
if len(toack) > count {
toack = toack[:count]
}
logf(logTypeHandshake, "Sending ACK: [%x]", toack)
ack := &DtlsAck{toack}
body, err := syntax.Marshal(&ack)
if err != nil {
return err
}
err = h.hOut.conn.WriteRecord(&TLSPlaintext{
contentType: RecordTypeAck,
fragment: body,
})
if err != nil {
return err
}
return nil
}
func (h *HandshakeContext) processAck(data []byte) error {
// Cancel the retransmit timer because we will be resending
// and possibly re-arming later.
h.timers.cancel(retransmitTimerLabel)
ack := &DtlsAck{}
read, err := syntax.Unmarshal(data, &ack)
if err != nil {
return err
}
if len(data) != read {
return fmt.Errorf("Invalid encoding: Extra data not consumed")
}
logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers)
for _, r := range ack.RecordNumbers {
for _, m := range h.sentFragments {
if r == m.record {
logf(logTypeHandshake, "Marking %v %v(%v) as acked",
m.seq, m.offset, m.fragLength)
m.acked = true
}
}
}
count, err := h.hOut.SendQueuedMessages()
if err != nil {
return err
}
if count == 0 {
logf(logTypeHandshake, "All messages ACKed")
h.hOut.ClearQueuedMessages()
return nil
}
// Reset the timer
h.timers.start(retransmitTimerLabel,
h.handshakeRetransmit,
h.timeoutMS)
return nil
}
func (c *Conn) GetDTLSTimeout() (bool, time.Duration) {
return c.hsCtx.timers.remaining()
}
func (h *HandshakeContext) receivedHandshakeMessage() {
logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight)
// This just enables tests.
if h.hIn == nil {
return
}
if !h.hIn.datagram {
return
}
if h.waitingNextFlight {
logf(logTypeHandshake, "Received the start of the flight")
// Clear the outgoing DTLS queue and terminate the retransmit timer
h.hOut.ClearQueuedMessages()
h.timers.cancel(retransmitTimerLabel)
// OK, we're not waiting any more.
h.waitingNextFlight = false
}
// Now pre-emptively arm the ACK timer if it's not armed already.
// We'll automatically dis-arm it at the end of the handshake.
if h.timers.getTimer(ackTimerLabel) == nil {
h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4)
}
}
func (h *HandshakeContext) receivedEndOfFlight() {
logf(logTypeHandshake, "%p Received the end of the flight", h)
if !h.hIn.datagram {
return
}
// Empty incoming queue
h.hIn.queued = nil
// Note that we are waiting for the next flight.
h.waitingNextFlight = true
// Clear the ACK queue.
h.hIn.recvdRecords = nil
// Disarm the ACK timer
h.timers.cancel(ackTimerLabel)
}
func (h *HandshakeContext) receivedFinalFlight() {
logf(logTypeHandshake, "%p Received final flight", h)
if !h.hIn.datagram {
return
}
// Disarm the ACK timer
h.timers.cancel(ackTimerLabel)
// But send an ACK immediately.
h.sendAck()
}
func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool {
logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen)
for _, f := range h.sentFragments {
if !f.acked {
continue
}
if f.seq != seq {
continue
}
if f.offset > offset {
continue
}
// At this point, we know that the stored fragment starts
// at or before what we want to send, so check where the end
// is.
if f.offset+f.fragLength < offset+fraglen {
continue
}
return true
}
return false
}

View File

@@ -67,7 +67,7 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
f.writeOffset += copied
if f.writeOffset < len(f.working) {
logf(logTypeVerbose, "Read would have blocked 1")
return nil, nil, WouldBlock
return nil, nil, AlertWouldBlock
}
// Reset the write offset, because we are now full.
f.writeOffset = 0
@@ -94,5 +94,5 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) {
}
logf(logTypeVerbose, "Read would have blocked 2")
return nil, nil, WouldBlock
return nil, nil, AlertWouldBlock
}

View File

@@ -35,7 +35,6 @@ type HandshakeMessage struct {
datagram bool
offset uint32 // Used for DTLS
length uint32
records []uint64 // Used for DTLS
cipher *cipherState
}
@@ -119,6 +118,7 @@ func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*H
}
type HandshakeLayer struct {
ctx *HandshakeContext // The handshake we are attached to
nonblocking bool // Should we operate in nonblocking mode
conn *RecordLayer // Used for reading/writing records
frame *frameReader // The buffered frame reader
@@ -126,6 +126,7 @@ type HandshakeLayer struct {
msgSeq uint32 // The DTLS message sequence number
queued []*HandshakeMessage // In/out queue
sent []*HandshakeMessage // Sent messages for DTLS
recvdRecords []uint64 // Records we have received.
maxFragmentLen int
}
@@ -152,8 +153,9 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) {
return int(val), nil
}
func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer {
func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.ctx = c
h.conn = r
h.datagram = false
h.frame = newFrameReader(&handshakeLayerFrameDetails{false})
@@ -161,8 +163,9 @@ func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer {
return &h
}
func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer {
func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer {
h := HandshakeLayer{}
h.ctx = c
h.conn = r
h.datagram = true
h.frame = newFrameReader(&handshakeLayerFrameDetails{true})
@@ -172,16 +175,25 @@ func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer {
func (h *HandshakeLayer) readRecord() error {
logf(logTypeVerbose, "Trying to read record")
pt, err := h.conn.ReadRecord()
pt, err := h.conn.readRecordAnyEpoch()
if err != nil {
return err
}
if pt.contentType != RecordTypeHandshake &&
pt.contentType != RecordTypeAlert {
switch pt.contentType {
case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck:
default:
return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType)
}
if pt.contentType == RecordTypeAck {
if !h.datagram {
return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS")
}
logf(logTypeIO, "read ACK")
return h.ctx.processAck(pt.fragment)
}
if pt.contentType == RecordTypeAlert {
logf(logTypeIO, "read alert %v", pt.fragment[1])
if len(pt.fragment) < 2 {
@@ -191,6 +203,19 @@ func (h *HandshakeLayer) readRecord() error {
return Alert(pt.fragment[1])
}
assert(h.ctx.hIn.conn != nil)
if pt.epoch != h.ctx.hIn.conn.cipher.epoch {
// This is out of order but we're dropping it.
// TODO(ekr@rtfm.com): If server, need to retransmit Finished.
if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData {
return nil
}
// Anything else shouldn't happen.
return AlertIllegalParameter
}
h.recvdRecords = append(h.recvdRecords, pt.seq)
h.frame.addChunk(pt.fragment)
return nil
@@ -227,9 +252,13 @@ func (h *HandshakeLayer) noteMessageDelivered(seq uint32) {
func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) {
if hm.seq < h.msgSeq {
return nil, WouldBlock
return nil, nil
}
// TODO(ekr@rtfm.com): Send an ACK immediately if we got something
// out of order.
h.ctx.receivedHandshakeMessage()
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
// TODO(ekr@rtfm.com): Check the length?
// This is complete.
@@ -259,12 +288,12 @@ func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMe
func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
if len(h.queued) == 0 {
return nil, WouldBlock
return nil, nil
}
hm := h.queued[0]
if hm.seq != h.msgSeq {
return nil, WouldBlock
return nil, nil
}
if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) {
@@ -307,7 +336,7 @@ func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) {
}
return nil, WouldBlock
return nil, nil
}
func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
@@ -315,19 +344,19 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
var err error
hm, err := h.checkMessageAvailable()
if err == nil {
return hm, err
}
if err != WouldBlock {
if err != nil {
return nil, err
}
if hm != nil {
return hm, nil
}
for {
logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder))
if h.frame.needed() > 0 {
logf(logTypeVerbose, "Trying to read a new record")
err = h.readRecord()
if err != nil && (h.nonblocking || err != WouldBlock) {
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
return nil, err
}
}
@@ -336,7 +365,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) {
if err == nil {
break
}
if err != nil && (h.nonblocking || err != WouldBlock) {
if err != nil && (h.nonblocking || err != AlertWouldBlock) {
return nil, err
}
}
@@ -370,12 +399,13 @@ func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error {
return nil
}
func (h *HandshakeLayer) SendQueuedMessages() error {
func (h *HandshakeLayer) SendQueuedMessages() (int, error) {
logf(logTypeHandshake, "Sending outgoing messages")
err := h.WriteMessages(h.queued)
h.ClearQueuedMessages() // This isn't going to work for DTLS, but we'll
// get there.
return err
count, err := h.WriteMessages(h.queued)
if !h.datagram {
h.ClearQueuedMessages()
}
return count, err
}
func (h *HandshakeLayer) ClearQueuedMessages() {
@@ -383,7 +413,7 @@ func (h *HandshakeLayer) ClearQueuedMessages() {
h.queued = nil
}
func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (int, error) {
func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) {
var buf []byte
// Figure out if we're going to want the full header or just
@@ -408,17 +438,35 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int
}
body := hm.body[start : start+bodylen]
// Now see if this chunk has been ACKed. This doesn't produce ideal
// retransmission but is simple.
if h.ctx.fragmentAcked(hm.seq, start, bodylen) {
logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen)
return false, start + bodylen, nil
}
// Encode the data.
if hdrlen > 0 {
hm2 := *hm
hm2.offset = uint32(start)
hm2.body = body
buf = hm2.Marshal()
hm = &hm2
} else {
buf = body
}
return start + bodylen, h.conn.writeRecordWithPadding(
if h.datagram {
// Remember that we sent this.
h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{
hm.seq,
start,
len(body),
h.conn.cipher.combineSeq(true),
false,
})
}
return true, start + bodylen, h.conn.writeRecordWithPadding(
&TLSPlaintext{
contentType: RecordTypeHandshake,
fragment: buf,
@@ -426,38 +474,46 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int
hm.cipher, 0)
}
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error {
func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) {
start := int(0)
if len(hm.body) > maxHandshakeMessageLen {
return fmt.Errorf("Tried to write a handshake message that's too long")
return 0, fmt.Errorf("Tried to write a handshake message that's too long")
}
written := 0
wrote := false
// Always make one pass through to allow EOED (which is empty).
for {
var err error
start, err = h.writeFragment(hm, start, h.maxFragmentLen)
wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen)
if err != nil {
return err
return 0, err
}
if wrote {
written++
}
if start >= len(hm.body) {
break
}
}
return nil
return written, nil
}
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error {
func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) {
written := 0
for _, hm := range hms {
logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body)
err := h.WriteMessage(hm)
wrote, err := h.WriteMessage(hm)
if err != nil {
return err
return 0, err
}
written += wrote
}
return nil
return written, nil
}
func encodeUint(v uint64, size int, out []byte) []byte {

View File

@@ -168,10 +168,11 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme
return nil, 0, fmt.Errorf("No certificates compatible with signature schemes")
}
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool {
usingEarlyData := gotEarlyData && usingPSK && allowEarlyData
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData)
return usingEarlyData
func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) {
using = gotEarlyData && usingPSK && allowEarlyData
rejected = gotEarlyData && !using
logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected)
return
}
func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) {

View File

@@ -1,7 +1,6 @@
package mint
import (
"bytes"
"crypto/cipher"
"fmt"
"io"
@@ -21,6 +20,13 @@ func (err DecryptError) Error() string {
return string(err)
}
type direction uint8
const (
directionWrite = direction(1)
directionRead = direction(2)
)
// struct {
// ContentType type;
// ProtocolVersion record_version [0301 for CH, 0303 for others]
@@ -31,20 +37,23 @@ type TLSPlaintext struct {
// Omitted: record_version (static)
// Omitted: length (computed from fragment)
contentType RecordType
epoch Epoch
seq uint64
fragment []byte
}
type cipherState struct {
epoch Epoch // DTLS epoch
ivLength int // Length of the seq and nonce fields
seq []byte // Zero-padded sequence number
seq uint64 // Zero-padded sequence number
iv []byte // Buffer for the IV
cipher cipher.AEAD // AEAD cipher
}
type RecordLayer struct {
sync.Mutex
label string
direction direction
version uint16 // The current version number
conn io.ReadWriter // The underlying connection
frame *frameReader // The buffered frame reader
@@ -52,7 +61,9 @@ type RecordLayer struct {
cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
cachedError error // Error on the last record read
cipher *cipherState
cipher *cipherState
readCiphers map[Epoch]*cipherState
datagram bool
}
@@ -76,7 +87,7 @@ func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
}
func newCipherStateNull() *cipherState {
return &cipherState{EpochClear, 0, bytes.Repeat([]byte{0}, sequenceNumberLen), nil, nil}
return &cipherState{EpochClear, 0, 0, nil, nil}
}
func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) {
@@ -85,11 +96,13 @@ func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte)
return nil, err
}
return &cipherState{epoch, len(iv), bytes.Repeat([]byte{0}, sequenceNumberLen), iv, cipher}, nil
return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
}
func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer {
func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer {
r := RecordLayer{}
r.label = ""
r.direction = dir
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{false})
r.cipher = newCipherStateNull()
@@ -97,11 +110,15 @@ func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer {
return &r
}
func NewRecordLayerDTLS(conn io.ReadWriter) *RecordLayer {
func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer {
r := RecordLayer{}
r.label = ""
r.direction = dir
r.conn = conn
r.frame = newFrameReader(recordLayerFrameDetails{true})
r.cipher = newCipherStateNull()
r.readCiphers = make(map[Epoch]*cipherState, 0)
r.readCiphers[0] = r.cipher
r.datagram = true
return &r
}
@@ -110,53 +127,67 @@ func (r *RecordLayer) SetVersion(v uint16) {
r.version = v
}
func (r *RecordLayer) ResetClear(seq uint64) {
r.cipher = newCipherStateNull()
r.cipher.seq = seq
}
func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
cipher, err := newCipherStateAead(epoch, factory, key, iv)
if err != nil {
return err
}
r.cipher = cipher
if r.datagram && r.direction == directionRead {
r.readCiphers[epoch] = cipher
}
return nil
}
func (c *cipherState) formatSeq(datagram bool) []byte {
seq := append([]byte{}, c.seq...)
// TODO(ekr@rtfm.com): This is never used, which is a bug.
func (r *RecordLayer) DiscardReadKey(epoch Epoch) {
if !r.datagram {
return
}
_, ok := r.readCiphers[epoch]
assert(ok)
delete(r.readCiphers, epoch)
}
func (c *cipherState) combineSeq(datagram bool) uint64 {
seq := c.seq
if datagram {
seq[0] = byte(c.epoch >> 8)
seq[1] = byte(c.epoch & 0xff)
seq |= uint64(c.epoch) << 48
}
return seq
}
func (c *cipherState) computeNonce(seq []byte) []byte {
func (c *cipherState) computeNonce(seq uint64) []byte {
nonce := make([]byte, len(c.iv))
copy(nonce, c.iv)
offset := len(c.iv) - len(seq)
for i, b := range seq {
nonce[i+offset] ^= b
s := seq
offset := len(c.iv)
for i := 0; i < 8; i++ {
nonce[(offset-i)-1] ^= byte(s & 0xff)
s >>= 8
}
logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
return nonce
}
func (c *cipherState) incrementSequenceNumber() {
var i int
for i = len(c.seq) - 1; i >= 0; i-- {
c.seq[i]++
if c.seq[i] != 0 {
break
}
}
if i < 0 {
if c.seq >= (1<<48 - 1) {
// Not allowed to let sequence number wrap.
// Instead, must renegotiate before it does.
// Not likely enough to bother.
// TODO(ekr@rtfm.com): Check for DTLS here
// because the limit is sooner.
// Not likely enough to bother. This is the
// DTLS limit.
panic("TLS: sequence number wraparound")
}
c.seq++
}
func (c *cipherState) overhead() int {
@@ -166,8 +197,9 @@ func (c *cipherState) overhead() int {
return c.cipher.Overhead()
}
func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, padLen int) *TLSPlaintext {
logf(logTypeIO, "Encrypt seq=[%x]", seq)
func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext {
assert(r.direction == directionWrite)
logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
// Expand the fragment to hold contentType, padding, and overhead
originalLen := len(pt.fragment)
plaintextLen := originalLen + 1 + padLen
@@ -191,8 +223,9 @@ func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext,
return out
}
func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, error) {
logf(logTypeIO, "Decrypt seq=[%x]", seq)
func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) {
assert(r.direction == directionRead)
logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
if len(pt.fragment) < r.cipher.overhead() {
msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
return nil, 0, DecryptError(msg)
@@ -207,7 +240,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int,
// Decrypt
_, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil)
if err != nil {
logf(logTypeIO, "AEAD decryption failure [%x]", pt)
logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
}
@@ -222,6 +255,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int,
// Truncate the message to remove contentType, padding, overhead
out.fragment = out.fragment[:newLen]
out.seq = seq
return out, padLen, nil
}
@@ -230,11 +264,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
var err error
for {
pt, err = r.nextRecord()
pt, err = r.nextRecord(false)
if err == nil {
break
}
if !block || err != WouldBlock {
if !block || err != AlertWouldBlock {
return 0, err
}
}
@@ -242,7 +276,7 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
}
func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
pt, err := r.nextRecord()
pt, err := r.nextRecord(false)
// Consume the cached record if there was one
r.cachedRecord = nil
@@ -251,10 +285,20 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
return pt, err
}
func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) {
pt, err := r.nextRecord(true)
// Consume the cached record if there was one
r.cachedRecord = nil
r.cachedError = nil
return pt, err
}
func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
cipher := r.cipher
if r.cachedRecord != nil {
logf(logTypeIO, "Returning cached record")
logf(logTypeIO, "%s Returning cached record", r.label)
return r.cachedRecord, r.cachedError
}
@@ -262,9 +306,10 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
//
// 1. We get a frame
// 2. We try to read off the socket and get nothing, in which case
// return WouldBlock
// returnAlertWouldBlock
// 3. We get an error.
err := WouldBlock
var err error
err = AlertWouldBlock
var header, body []byte
for err != nil {
@@ -272,24 +317,24 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
n, err := r.conn.Read(buf)
if err != nil {
logf(logTypeIO, "Error reading, %v", err)
logf(logTypeIO, "%s Error reading, %v", r.label, err)
return nil, err
}
if n == 0 {
return nil, WouldBlock
return nil, AlertWouldBlock
}
logf(logTypeIO, "Read %v bytes", n)
logf(logTypeIO, "%s Read %v bytes", r.label, n)
buf = buf[:n]
r.frame.addChunk(buf)
}
header, body, err = r.frame.process()
// Loop around on WouldBlock to see if some
// Loop around onAlertWouldBlock to see if some
// data is now available.
if err != nil && err != WouldBlock {
if err != nil && err != AlertWouldBlock {
return nil, err
}
}
@@ -299,7 +344,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
switch RecordType(header[0]) {
default:
return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData:
case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
pt.contentType = RecordType(header[0])
}
@@ -318,28 +363,48 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) {
pt.fragment = make([]byte, size)
copy(pt.fragment, body)
// TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.
// Attempt to decrypt fragment
if cipher.cipher != nil {
seq := cipher.seq
if r.datagram {
seq = header[3:11]
}
// TODO(ekr@rtfm.com): Handle the wrong epoch.
seq := cipher.seq
if r.datagram {
// TODO(ekr@rtfm.com): Handle duplicates.
logf(logTypeIO, "RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), seq, pt.contentType, pt.fragment)
seq, _ = decodeUint(header[3:11], 8)
epoch := Epoch(seq >> 48)
// Look up the cipher suite from the epoch
c, ok := r.readCiphers[epoch]
if !ok {
logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
return nil, AlertWouldBlock
}
if epoch != cipher.epoch {
logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
cipher.epoch, allowOldEpoch)
if !allowOldEpoch {
return nil, AlertWouldBlock
}
cipher = c
}
}
if cipher.cipher != nil {
logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
pt, _, err = r.decrypt(pt, seq)
if err != nil {
logf(logTypeIO, "Decryption failed")
logf(logTypeIO, "%s Decryption failed", r.label)
return nil, err
}
}
pt.epoch = cipher.epoch
// Check that plaintext length is not too long
if len(pt.fragment) > maxFragmentLen {
return nil, fmt.Errorf("tls.record: Plaintext size too big")
}
logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment)
logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)
r.cachedRecord = pt
cipher.incrementSequenceNumber()
@@ -355,10 +420,9 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error
}
func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
seq := cipher.formatSeq(r.datagram)
seq := cipher.combineSeq(r.datagram)
if cipher.cipher != nil {
logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
pt = r.encrypt(cipher, seq, pt, padLen)
} else if padLen > 0 {
return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
@@ -376,16 +440,17 @@ func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherSta
byte(r.version >> 8), byte(r.version & 0xff),
byte(length >> 8), byte(length)}
} else {
header = make([]byte, 13)
version := dtlsConvertVersion(r.version)
header = []byte{byte(pt.contentType),
copy(header, []byte{byte(pt.contentType),
byte(version >> 8), byte(version & 0xff),
seq[0], seq[1], seq[2], seq[3],
seq[4], seq[5], seq[6], seq[7],
byte(length >> 8), byte(length)}
})
encodeUint(seq, 8, header[3:])
encodeUint(uint64(length), 2, header[11:])
}
record := append(header, pt.fragment...)
logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
cipher.incrementSequenceNumber()
_, err := r.conn.Write(record)

View File

@@ -24,14 +24,17 @@ import (
// | [Send CertificateRequest]
// Can send | [Send Certificate + CertificateVerify]
// app data --> | Send Finished
// after +--------+--------+
// here No 0-RTT | | 0-RTT
// | v
// | WAIT_EOED <---+
// | Recv | | | Recv
// | EndOfEarlyData | | | early data
// | | +-----+
// +> WAIT_FLIGHT2 <-+
// after here |
// +-----------+--------+
// | | |
// Rejected 0-RTT | No | | 0-RTT
// | 0-RTT | |
// | | v
// +---->READ_PAST | WAIT_EOED <---+
// Decrypt | | | Decrypt | Recv | | | Recv
// error | | | OK + HS | EOED | | | early data
// +-----+ | V | +-----+
// +---> WAIT_FLIGHT2 <-+
// |
// +--------+--------+
// No auth | | Client auth
@@ -50,16 +53,17 @@ import (
//
// NB: Not using state RECVD_CH
//
// State Instructions
// START {}
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
// WAIT_EOED RekeyIn;
// WAIT_FLIGHT2 {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
// State Instructions
// START {}
// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)]
// WAIT_EOED RekeyIn;
// READ_PAST {}
// WAIT_FLIGHT2 {}
// WAIT_CERT_CR {}
// WAIT_CERT {}
// WAIT_CV {}
// WAIT_FINISHED RekeyIn; RekeyOut;
// CONNECTED StoreTicket || (RekeyIn; [RekeyOut])
// A cookie can be sent to the client in a HRR.
type cookie struct {
@@ -74,7 +78,7 @@ type cookie struct {
type serverStateStart struct {
Config *Config
conn *Conn
hsCtx HandshakeContext
hsCtx *HandshakeContext
}
var _ HandshakeState = &serverStateStart{}
@@ -235,10 +239,6 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err)
return nil, nil, AlertInternalError
}
if clientSentCookie && initialCipherSuite.Suite != params.Suite {
logf(logTypeHandshake, "[ServerStateStart] Would have selected a different CipherSuite after receiving the client's Cookie")
return nil, nil, AlertInternalError
}
}
// Figure out if we actually should do DH / PSK
@@ -361,7 +361,7 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
// Figure out if we're going to do early data
var clientEarlyTrafficSecret []byte
connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData]
connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData)
connParams.UsingEarlyData, connParams.RejectedEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData)
if connParams.UsingEarlyData {
h := params.Hash.New()
h.Write(clientHello.Marshal())
@@ -379,6 +379,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [
return nil, nil, AlertNoApplicationProtocol
}
state.hsCtx.receivedEndOfFlight()
logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]")
state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2.
return serverStateNegotiated{
@@ -445,7 +447,7 @@ func (state *serverStateStart) generateHRR(cs CipherSuite, legacySessionId []byt
type serverStateNegotiated struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
dhGroup NamedGroup
dhPublic []byte
dhSecret []byte
@@ -731,7 +733,6 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
}
toSend = append(toSend, []HandshakeAction{
RekeyIn{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys},
ReadEarlyData{},
}...)
return nextState, toSend, AlertNoAlert
}
@@ -739,9 +740,9 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]")
toSend = append(toSend, []HandshakeAction{
RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys},
ReadPastEarlyData{},
}...)
waitFlight2 := serverStateWaitFlight2{
var nextState HandshakeState
nextState = serverStateWaitFlight2{
Config: state.Config,
Params: state.Params,
hsCtx: state.hsCtx,
@@ -753,13 +754,19 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat
serverTrafficSecret: serverTrafficSecret,
exporterSecret: exporterSecret,
}
return waitFlight2, toSend, AlertNoAlert
if state.Params.RejectedEarlyData {
nextState = serverStateReadPastEarlyData{
hsCtx: state.hsCtx,
next: &nextState,
}
}
return nextState, toSend, AlertNoAlert
}
type serverStateWaitEOED struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
@@ -776,6 +783,38 @@ func (state serverStateWaitEOED) State() State {
}
func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
for {
logf(logTypeHandshake, "Server reading early data...")
assert(state.hsCtx.hIn.conn.cipher.epoch == EpochEarlyData)
t, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking)
if err == AlertWouldBlock {
return nil, nil, AlertWouldBlock
}
if err != nil {
logf(logTypeHandshake, "Server Error reading record type (1): %v", err)
return nil, nil, AlertBadRecordMAC
}
logf(logTypeHandshake, "Server got record type(1): %v", t)
if t != RecordTypeApplicationData {
break
}
// Read a record into the buffer. Note that this is safe
// in blocking mode because we read the record in
// PeekRecordType.
pt, err := state.hsCtx.hIn.conn.ReadRecord()
if err != nil {
logf(logTypeHandshake, "Server error reading early data record: %v", err)
return nil, nil, AlertInternalError
}
logf(logTypeHandshake, "Server read early data: %x", pt.fragment)
state.hsCtx.earlyData = append(state.hsCtx.earlyData, pt.fragment...)
}
hm, alert := hr.ReadMessage()
if alert != AlertNoAlert {
return nil, nil, alert
@@ -813,10 +852,44 @@ func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState
return waitFlight2, toSend, AlertNoAlert
}
var _ HandshakeState = &serverStateReadPastEarlyData{}
type serverStateReadPastEarlyData struct {
hsCtx *HandshakeContext
next *HandshakeState
}
func (state serverStateReadPastEarlyData) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) {
for {
logf(logTypeHandshake, "Server reading past early data...")
// Scan past all records that fail to decrypt
_, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking)
if err == nil {
break
}
if err == AlertWouldBlock {
return nil, nil, AlertWouldBlock
}
// Continue on DecryptError
_, ok := err.(DecryptError)
if !ok {
return nil, nil, AlertInternalError // Really need something else.
}
}
return *state.next, nil, AlertNoAlert
}
func (state serverStateReadPastEarlyData) State() State {
return StateServerReadPastEarlyData
}
type serverStateWaitFlight2 struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
@@ -868,7 +941,7 @@ func (state serverStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeSta
type serverStateWaitCert struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
masterSecret []byte
clientHandshakeTrafficSecret []byte
@@ -940,7 +1013,7 @@ func (state serverStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState
type serverStateWaitCV struct {
Config *Config
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
masterSecret []byte
@@ -1023,7 +1096,7 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState,
type serverStateWaitFinished struct {
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
cryptoParams CipherSuiteParams
masterSecret []byte
@@ -1082,6 +1155,8 @@ func (state serverStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS
// Compute client traffic keys
clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret)
state.hsCtx.receivedFinalFlight()
logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]")
nextState := stateConnected{
Params: state.Params,

View File

@@ -17,10 +17,6 @@ type SendQueuedHandshake struct{}
type SendEarlyData struct{}
type ReadEarlyData struct{}
type ReadPastEarlyData struct{}
type RekeyIn struct {
epoch Epoch
KeySet keySet
@@ -31,6 +27,10 @@ type RekeyOut struct {
KeySet keySet
}
type ResetOut struct {
seq uint64
}
type StorePSK struct {
PSK PreSharedKey
}
@@ -50,7 +50,6 @@ type AppExtensionHandler interface {
type ConnectionOptions struct {
ServerName string
NextProtos []string
EarlyData []byte
}
// ConnectionParameters objects represent the parameters negotiated for a
@@ -60,6 +59,7 @@ type ConnectionParameters struct {
UsingDH bool
ClientSendingEarlyData bool
UsingEarlyData bool
RejectedEarlyData bool
UsingClientAuth bool
CipherSuite CipherSuite
@@ -69,7 +69,13 @@ type ConnectionParameters struct {
// Working state for the handshake.
type HandshakeContext struct {
hIn, hOut *HandshakeLayer
timeoutMS uint32
timers *timerSet
recvdRecords []uint64
sentFragments []*SentHandshakeFragment
hIn, hOut *HandshakeLayer
waitingNextFlight bool
earlyData []byte
}
func (hc *HandshakeContext) SetVersion(version uint16) {
@@ -84,7 +90,7 @@ func (hc *HandshakeContext) SetVersion(version uint16) {
// stateConnected is symmetric between client and server
type stateConnected struct {
Params ConnectionParameters
hsCtx HandshakeContext
hsCtx *HandshakeContext
isClient bool
cryptoParams CipherSuiteParams
resumptionSecret []byte

122
vendor/github.com/bifurcation/mint/timer.go generated vendored Normal file
View File

@@ -0,0 +1,122 @@
package mint
import (
"time"
)
// This is a simple timer implementation. Timers are stored in a sorted
// list.
// TODO(ekr@rtfm.com): Add a way to uncouple these from the system
// clock.
type timerCb func() error
type timer struct {
label string
cb timerCb
deadline time.Time
duration uint32
}
type timerSet struct {
ts []*timer
}
func newTimerSet() *timerSet {
return &timerSet{}
}
func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer {
now := time.Now()
t := timer{
label,
cb,
now.Add(time.Millisecond * time.Duration(delayMs)),
delayMs,
}
logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline)
var i int
ntimers := len(ts.ts)
for i = 0; i < ntimers; i++ {
if t.deadline.Before(ts.ts[i].deadline) {
break
}
}
tmp := make([]*timer, 0, ntimers+1)
tmp = append(tmp, ts.ts[:i]...)
tmp = append(tmp, &t)
tmp = append(tmp, ts.ts[i:]...)
ts.ts = tmp
return &t
}
// TODO(ekr@rtfm.com): optimize this now that the list is sorted.
// We should be able to do just one list manipulation, as long
// as we're careful about how we handle inserts during callbacks.
func (ts *timerSet) check(now time.Time) error {
for i, t := range ts.ts {
if now.After(t.deadline) {
ts.ts = append(ts.ts[:i], ts.ts[:i+1]...)
if t.cb != nil {
logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline)
cb := t.cb
t.cb = nil
err := cb()
if err != nil {
return err
}
}
} else {
break
}
}
return nil
}
// Returns the next time any of the timers would fire.
func (ts *timerSet) remaining() (bool, time.Duration) {
for _, t := range ts.ts {
if t.cb != nil {
return true, time.Until(t.deadline)
}
}
return false, time.Duration(0)
}
func (ts *timerSet) cancel(label string) {
for _, t := range ts.ts {
if t.label == label {
t.cancel()
}
}
}
func (ts *timerSet) getTimer(label string) *timer {
for _, t := range ts.ts {
if t.label == label && t.cb != nil {
return t
}
}
return nil
}
func (ts *timerSet) getAllTimers() []string {
var ret []string
for _, t := range ts.ts {
if t.cb != nil {
ret = append(ret, t.label)
}
}
return ret
}
func (t *timer) cancel() {
logf(logTypeHandshake, "Timer %s cancelled", t.label)
t.cb = nil
t.label = ""
}