package obfs

import (
	"fmt"
	"math/rand"
	"sync"
	"time"

	"golang.org/x/crypto/blake2b"
)

const (
	smPSKMinLen = 4
	smSaltLen   = 8
	smKeyLen    = blake2b.Size256
)

var _ Obfuscator = (*SalamanderObfuscator)(nil)

var ErrPSKTooShort = fmt.Errorf("PSK must be at least %d bytes", smPSKMinLen)

// SalamanderObfuscator is an obfuscator that obfuscates each packet with
// the BLAKE2b-256 hash of a pre-shared key combined with a random salt.
// Packet format: [8-byte salt][payload]
type SalamanderObfuscator struct {
	PSK     []byte
	RandSrc *rand.Rand

	lk sync.Mutex
}

func NewSalamanderObfuscator(psk []byte) (*SalamanderObfuscator, error) {
	if len(psk) < smPSKMinLen {
		return nil, ErrPSKTooShort
	}
	return &SalamanderObfuscator{
		PSK:     psk,
		RandSrc: rand.New(rand.NewSource(time.Now().UnixNano())),
	}, nil
}

func (o *SalamanderObfuscator) Obfuscate(in, out []byte) int {
	outLen := len(in) + smSaltLen
	if len(out) < outLen {
		return 0
	}
	o.lk.Lock()
	_, _ = o.RandSrc.Read(out[:smSaltLen])
	o.lk.Unlock()
	key := o.key(out[:smSaltLen])
	for i, c := range in {
		out[i+smSaltLen] = c ^ key[i%smKeyLen]
	}
	return outLen
}

func (o *SalamanderObfuscator) Deobfuscate(in, out []byte) int {
	outLen := len(in) - smSaltLen
	if outLen <= 0 || len(out) < outLen {
		return 0
	}
	key := o.key(in[:smSaltLen])
	for i, c := range in[smSaltLen:] {
		out[i] = c ^ key[i%smKeyLen]
	}
	return outLen
}

func (o *SalamanderObfuscator) key(salt []byte) [smKeyLen]byte {
	return blake2b.Sum256(append(o.PSK, salt...))
}