package idempotency_test

import (
	"strconv"
	"sync/atomic"
	"testing"
	"time"

	"github.com/gofiber/fiber/v3/middleware/idempotency"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// go test -run Test_MemoryLock
//
// Test_MemoryLock checks the basic contract of the in-memory lock: a held key
// blocks other lockers until it is released (and is then handed to the
// waiter), keys can be locked again after being unlocked, and unlocking a key
// that was never locked returns no error.
func Test_MemoryLock(t *testing.T) {
	t.Parallel()

	l := idempotency.NewMemoryLock()

	// A lock can be acquired.
	require.NoError(t, l.Lock("a"))

	// The same key cannot be acquired again while held: a second Lock call
	// must stay blocked until the current holder releases it.
	done := make(chan struct{})
	go func() {
		defer close(done)

		err := l.Lock("a")
		assert.NoError(t, err)
	}()

	select {
	case <-done:
		t.Fatal("lock acquired again")
	case <-time.After(time.Second):
		// Expected: the goroutine is still blocked.
	}

	// Releasing the key must let the waiting goroutine acquire it. Waiting on
	// done here also guarantees the goroutine (and its use of t) finishes
	// before the test returns.
	require.NoError(t, l.Unlock("a"))
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("lock not acquired after release")
	}

	// The waiter now holds "a"; release it so no key is left locked.
	require.NoError(t, l.Unlock("a"))

	// A key can be locked again after it has been unlocked.
	for range 2 {
		require.NoError(t, l.Lock("b"))
		require.NoError(t, l.Unlock("b"))
	}

	// Unlocking a key that was never locked succeeds.
	require.NoError(t, l.Unlock("c"))

	// An unrelated key can be locked and unlocked independently.
	require.NoError(t, l.Lock("d"))
	require.NoError(t, l.Unlock("d"))
}

// Benchmark_MemoryLock measures a single-goroutine Lock/Unlock round trip,
// cycling through a fixed pool of pre-built keys.
func Benchmark_MemoryLock(b *testing.B) {
	// Pre-build the keys so string formatting is not part of the measurement.
	// The pool is bounded and indexed modulo its size. This keeps memory use
	// modest, and runs whose iteration count exceeds the pool size cannot
	// index past the end.
	keys := make([]string, 1_000_000)
	for i := range keys {
		keys[i] = strconv.Itoa(i)
	}

	lock := idempotency.NewMemoryLock()

	b.ReportAllocs()
	for i := 0; b.Loop(); i++ {
		key := keys[i%len(keys)]
		if err := lock.Lock(key); err != nil {
			b.Fatal(err)
		}
		if err := lock.Unlock(key); err != nil {
			b.Fatal(err)
		}
	}
}

// Benchmark_MemoryLock_Parallel measures Lock/Unlock round trips issued from
// parallel goroutines. It runs two cases: one where every operation uses the
// next key in sequence, and one where consecutive operations share keys, so
// goroutines may contend on the same key.
func Benchmark_MemoryLock_Parallel(b *testing.B) {
	// Pre-build the keys so string formatting is not part of the measurement.
	keys := make([]string, 1_000_000)
	for i := range keys {
		keys[i] = strconv.Itoa(i)
	}

	// run drives one sub-benchmark. keyIndex maps a shared, monotonically
	// increasing sequence number (starting at 1) to a raw key index, which is
	// then reduced modulo len(keys). The counter is unsigned so it cannot wrap
	// to a negative index.
	run := func(b *testing.B, keyIndex func(seq uint64) uint64) {
		lock := idempotency.NewMemoryLock()
		var seq atomic.Uint64
		b.ReportAllocs()
		b.ResetTimer()
		b.RunParallel(func(p *testing.PB) {
			for p.Next() {
				key := keys[keyIndex(seq.Add(1))%uint64(len(keys))]
				// RunParallel bodies run on separate goroutines, where
				// b.Fatal (FailNow) must not be called; report and stop.
				if err := lock.Lock(key); err != nil {
					b.Error(err)
					return
				}
				if err := lock.Unlock(key); err != nil {
					b.Error(err)
					return
				}
			}
		})
	}

	b.Run("UniqueKeys", func(b *testing.B) {
		run(b, func(seq uint64) uint64 { return seq })
	})

	b.Run("RepeatedKeys", func(b *testing.B) {
		// Dividing by 3 makes each index come up for three consecutive
		// sequence numbers, so up to three goroutines share a key.
		run(b, func(seq uint64) uint64 { return seq / 3 })
	})
}
