package middleware

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"lune/talentscale/infra/cache"
	"lune/talentscale/internal/domain"
	"lune/talentscale/pkg/response"

	"github.com/gofiber/fiber/v2"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
)

// AuthMiddleware returns a Fiber handler that authenticates a request from its
// "Authorization: Bearer <token>" header.
//
// The token is parsed with jwtSecret as the key and is only accepted when it
// is signed with an HMAC method. On success the "user_id", "company_id" and
// "role_id" claims are copied verbatim (without type checks) into c.Locals
// under the same keys and the next handler runs. Any failure short-circuits
// the chain with response.Unauthorized.
func AuthMiddleware(jwtSecret string) fiber.Handler {
	const bearerPrefix = "Bearer "

	// keyFor hands the shared secret to the parser, refusing any token that
	// is not signed with an HMAC method.
	keyFor := func(t *jwt.Token) (any, error) {
		if _, isHMAC := t.Method.(*jwt.SigningMethodHMAC); isHMAC {
			return []byte(jwtSecret), nil
		}
		return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
	}

	return func(c *fiber.Ctx) error {
		header := c.Get("Authorization")
		// An empty header fails the prefix check as well.
		if !strings.HasPrefix(header, bearerPrefix) {
			return response.Unauthorized(c, "Missing or invalid Authorization header")
		}

		parsed, err := jwt.Parse(header[len(bearerPrefix):], keyFor)
		if err != nil {
			return response.Unauthorized(c, "Invalid or expired token")
		}
		if !parsed.Valid {
			return response.Unauthorized(c, "Invalid or expired token")
		}

		claims, isMap := parsed.Claims.(jwt.MapClaims)
		if !isMap {
			return response.Unauthorized(c, "Invalid token claims")
		}

		// Expose the identity claims to downstream handlers.
		for _, key := range []string{"user_id", "company_id", "role_id"} {
			c.Locals(key, claims[key])
		}

		return c.Next()
	}
}

// PermissionMiddleware returns a Fiber handler that lets the request continue
// only when the caller's permission set contains requiredPermission.
//
// It expects "user_id" and "role_id" to be present in c.Locals as non-empty
// strings, and role_id must parse as a UUID. The permission names are read
// from the cache under "user:permissions:<user_id>"; on a miss (or an
// undecodable entry) they are loaded through roleRepo.GetPermissionsByRoleID
// and written back to the cache for one hour.
//
// Every failure path, including a repository error, responds through
// response.Forbidden.
//
// NOTE(review): cache entries are keyed by user ID but populated from the role
// ID carried in the request. If a user's role changes, the old permission set
// may keep being served until the entry expires or is invalidated elsewhere —
// confirm that invalidation exists.
func PermissionMiddleware(roleRepo interface {
	GetPermissionsByRoleID(ctx context.Context, roleID uuid.UUID) ([]domain.Permission, error)
}, requiredPermission string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		userIDStr, ok := c.Locals("user_id").(string)
		if !ok || userIDStr == "" {
			return response.Forbidden(c, "Access forbidden: missing user identity")
		}

		roleIDStr, ok := c.Locals("role_id").(string)
		if !ok || roleIDStr == "" {
			return response.Forbidden(c, "Access forbidden: missing role identity")
		}

		roleID, err := uuid.Parse(roleIDStr)
		if err != nil {
			return response.Forbidden(c, "Access forbidden: malformed role identity")
		}

		// A nil permissionNames means "not loaded yet"; a non-nil slice, even
		// an empty one, is a usable result.
		var permissionNames []string
		cacheKey := fmt.Sprintf("user:permissions:%s", userIDStr)

		// 1. Try the cache first. A read error, an empty value or an
		// undecodable payload all fall through to the repository.
		cachedPerms, err := cache.Get(cacheKey)
		if err == nil && cachedPerms != "" {
			if err := json.Unmarshal([]byte(cachedPerms), &permissionNames); err != nil {
				// Discard anything a failed decode may have partially filled in.
				permissionNames = nil
			}
		}

		// 2. Fall back to the repository.
		if permissionNames == nil {
			perms, err := roleRepo.GetPermissionsByRoleID(c.Context(), roleID)
			if err != nil {
				return response.Forbidden(c, "Access forbidden: account permission error")
			}

			// Allocate a non-nil slice even when the role has no permissions:
			// a nil slice marshals to "null", which decodes back to nil and
			// would be treated as a cache miss on every later request.
			permissionNames = make([]string, 0, len(perms))
			for _, p := range perms {
				permissionNames = append(permissionNames, p.Name)
			}

			// 3. Refresh the cache (valid for 1 hour).
			if permJSON, err := json.Marshal(permissionNames); err == nil {
				// Best effort: a failed cache write must not block the request;
				// the next request simply reloads from the repository.
				_ = cache.Set(cacheKey, string(permJSON), 1*time.Hour)
			}
		}

		// 4. Check for the required permission.
		hasAccess := false
		for _, name := range permissionNames {
			if name == requiredPermission {
				hasAccess = true
				break
			}
		}

		if !hasAccess {
			return response.Forbidden(c, fmt.Sprintf("Access forbidden: restricted action (%s)", requiredPermission))
		}

		return c.Next()
	}
}
