forked from quic-go/quic-go
use ASN1 to marshal source address tokens
This commit is contained in:
@@ -5,45 +5,18 @@ import (
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// StkSource is used to create and verify source address tokens
|
||||
type StkSource interface {
|
||||
// NewToken creates a new token for a given IP address
|
||||
NewToken(sourceAddress []byte) ([]byte, error)
|
||||
// DecodeToken decodes a token and returns the source address and the timestamp
|
||||
DecodeToken(data []byte) ([]byte, time.Time, error)
|
||||
}
|
||||
|
||||
type sourceAddressToken struct {
|
||||
data []byte
|
||||
// unix timestamp in seconds
|
||||
timestamp uint64
|
||||
}
|
||||
|
||||
func (t *sourceAddressToken) serialize() []byte {
|
||||
res := make([]byte, 8+len(t.data))
|
||||
binary.LittleEndian.PutUint64(res, t.timestamp)
|
||||
copy(res[8:], t.data)
|
||||
return res
|
||||
}
|
||||
|
||||
func parseToken(data []byte) (*sourceAddressToken, error) {
|
||||
if len(data) < 8 {
|
||||
return nil, fmt.Errorf("STK too short: %d", len(data))
|
||||
}
|
||||
return &sourceAddressToken{
|
||||
data: data[8:],
|
||||
timestamp: binary.LittleEndian.Uint64(data),
|
||||
}, nil
|
||||
// NewToken creates a new token
|
||||
NewToken([]byte) ([]byte, error)
|
||||
// DecodeToken decodes a token
|
||||
DecodeToken([]byte) ([]byte, error)
|
||||
}
|
||||
|
||||
type stkSource struct {
|
||||
@@ -78,30 +51,19 @@ func NewStkSource() (StkSource, error) {
|
||||
}
|
||||
|
||||
func (s *stkSource) NewToken(data []byte) ([]byte, error) {
|
||||
return encryptToken(s.aead, &sourceAddressToken{
|
||||
data: data,
|
||||
timestamp: uint64(time.Now().Unix()),
|
||||
})
|
||||
nonce := make([]byte, stkNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.aead.Seal(nonce, nonce, data, nil), nil
|
||||
}
|
||||
|
||||
func (s *stkSource) DecodeToken(data []byte) ([]byte, time.Time, error) {
|
||||
if len(data) < stkNonceSize {
|
||||
return nil, time.Time{}, errors.New("STK too short")
|
||||
func (s *stkSource) DecodeToken(p []byte) ([]byte, error) {
|
||||
if len(p) < stkNonceSize {
|
||||
return nil, fmt.Errorf("STK too short: %d", len(p))
|
||||
}
|
||||
nonce := data[:stkNonceSize]
|
||||
res, err := s.aead.Open(nil, nonce, data[stkNonceSize:], nil)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, err
|
||||
}
|
||||
|
||||
token, err := parseToken(res)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, err
|
||||
}
|
||||
if token.timestamp > math.MaxInt64 {
|
||||
return nil, time.Time{}, errors.New("invalid timestamp")
|
||||
}
|
||||
return token.data, time.Unix(int64(token.timestamp), 0), nil
|
||||
nonce := p[:stkNonceSize]
|
||||
return s.aead.Open(nil, nonce, p[stkNonceSize:], nil)
|
||||
}
|
||||
|
||||
func deriveKey(secret []byte) ([]byte, error) {
|
||||
@@ -112,11 +74,3 @@ func deriveKey(secret []byte) ([]byte, error) {
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func encryptToken(aead cipher.AEAD, token *sourceAddressToken) ([]byte, error) {
|
||||
nonce := make([]byte, stkNonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aead.Seal(nonce, nonce, token.serialize(), nil), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user