-- Centinel Bot Protection - Apache HTTP Server Module
-- Version: 1.1.0
-- Compatible with Apache 2.4+ with mod_lua

local cjson = require("cjson.safe")
local curl = require("lcurl")

-- Public module table; everything attached to it is returned to the caller
-- at the end of the file.
local _M = {}

-- Reported to the validator in the "x-origin-version" request header.
_M._VERSION = "1.1.0"

-- Lua patterns for static assets that skip validation by default.
-- should_protect_path() matches them against the lower-cased request URI.
local DEFAULT_UNPROTECTED_PATHS = {
    -- Video
    "%.avi$", "%.flv$", "%.mka$", "%.mkv$", "%.mov$",
    "%.mp4$", "%.mpeg$", "%.mpg$",
    -- Audio
    "%.mp3$", "%.flac$", "%.ogg$", "%.ogm$", "%.opus$",
    "%.wav$", "%.webm$",
    -- Images
    "%.webp$", "%.bmp$", "%.gif$", "%.ico$", "%.jpeg$",
    "%.jpg$", "%.png$", "%.svg$", "%.svgz$", "%.swf$",
    "%.avif$",
    -- Fonts
    "%.eot$", "%.otf$", "%.ttf$", "%.woff$", "%.woff2$",
    -- Code/Assets
    "%.css$", "%.less$", "%.js$", "%.map$", "%.json$",
    -- Archives
    "%.gz$", "%.zip$",
    -- Documents
    "%.xml$"
}

-- Module defaults. _M.init() copies user-supplied keys over these values.
local config = {
    -- Validator endpoint and credentials (secret_key must be set via init)
    secret_key = nil,
    validator_url = "https://validator.centinelanalytica.com/validate",

    -- libcurl timeouts, in milliseconds
    connect_timeout_ms = 2000,
    timeout_ms = 2000,

    -- TLS peer/host verification is switched off only when this is false
    ssl_verify = true,

    -- When the validator cannot be reached or understood: true lets the
    -- request through, false answers 503
    fail_open = true,

    -- Logging: log_enabled gates all output, debug gates "debug" messages
    log_enabled = true,
    debug = false,

    -- Path protection patterns
    protected_paths = {}, -- Empty = protect all
    unprotected_paths = DEFAULT_UNPROTECTED_PATHS
}

-- Circuit-breaker state, held in a file-level local and therefore shared by
-- every request served from the Lua state that loaded this module.
--   failures      consecutive validator failures recorded so far
--   backoff_until get_time_ms() timestamp before which validation is
--                 skipped; 0 means no backoff is active
-- NOTE(review): how long a Lua state (and so this table) survives depends on
-- the server's mod_lua LuaScope setting — confirm the deployment keeps it
-- across requests, otherwise backoff and connection reuse have no effect.
local backoff_state = { failures = 0, backoff_until = 0 }

-- Backoff tuning: first delay, ceiling (5 minutes), and growth factor
local INITIAL_BACKOFF_MS, MAX_BACKOFF_MS, BACKOFF_MULTIPLIER = 1000, 5 * 60 * 1000, 2

-- Shared lcurl easy handle, created on first use by get_curl_handle() and
-- dropped by destroy_curl_handle(). Keeping one handle is intended to let
-- libcurl reuse the TCP/TLS connection to the validator between calls.
local curl_handle

--------------------------------------------------------------------------------
-- Utility Functions
--------------------------------------------------------------------------------

--- Write a "[Centinel] "-prefixed message to the Apache log.
-- Nothing is written when config.log_enabled is false, and "debug" messages
-- are additionally dropped unless config.debug is true.
-- @param r     request object providing the err/warn/notice log methods
-- @param level "error", "warn", "debug"; any other value logs at notice
-- @param msg   message text
local function log(r, level, msg)
    if not config.log_enabled then
        return
    end

    if level == "debug" then
        if not config.debug then
            return
        end
        -- Debug output goes out at notice level so that the module's own
        -- debug toggle decides visibility, independent of Apache's LogLevel.
        r:notice("[Centinel] " .. "[DEBUG] " .. msg)
        return
    end

    local line = "[Centinel] " .. msg
    if level == "error" then
        r:err(line)
    elseif level == "warn" then
        r:warn(line)
    else
        r:notice(line)
    end
end

-- Standard base64 alphabet; position (0-based) is the 6-bit value.
local B64_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
local b64_lookup -- byte value -> 6-bit value, built on first use

-- Pure-Lua base64 decoder used when luasocket's "mime" module is missing.
-- Characters outside the alphabet (padding, whitespace) are skipped and any
-- incomplete trailing bits are discarded.
local function base64_decode_pure(data)
    if not b64_lookup then
        b64_lookup = {}
        for i = 1, #B64_ALPHABET do
            b64_lookup[B64_ALPHABET:byte(i)] = i - 1
        end
    end

    local out, n = {}, 0
    local acc, bits = 0, 0
    for i = 1, #data do
        local sextet = b64_lookup[data:byte(i)]
        if sextet then
            acc = acc * 64 + sextet
            bits = bits + 6
            if bits >= 8 then
                -- Peel the top 8 bits off the accumulator. Arithmetic is used
                -- instead of bit operators so this runs on Lua 5.1 as well.
                bits = bits - 8
                local scale = 2 ^ bits
                local byte = math.floor(acc / scale)
                acc = acc - byte * scale
                n = n + 1
                out[n] = string.char(byte)
            end
        end
    end
    return table.concat(out)
end

--- Decode a base64 string.
-- Uses luasocket's mime.unb64 when that module can be loaded and otherwise
-- falls back to a built-in decoder, so validator-supplied HTML is not lost
-- on hosts without luasocket's mime module.
-- @param data base64 text
-- @return decoded string, or nil when `data` is not a string (nil, or a
--         non-string value such as a JSON null placeholder)
local function base64_decode(data)
    if type(data) ~= "string" then return nil end

    local ok, mime = pcall(require, "mime")
    if ok and type(mime) == "table" and type(mime.unb64) == "function" then
        -- mime.unb64 may return more than one value; keep only the first.
        local decoded = mime.unb64(data)
        return decoded
    end

    return base64_decode_pure(data)
end

--- Determine the client IP address reported to the validator.
-- Checks, in order: the first entry of X-Forwarded-For, X-Real-IP,
-- CF-Connecting-IP (Cloudflare), True-Client-IP (Akamai, Cloudflare
-- Enterprise), and finally the connection address r.useragent_ip.
-- NOTE(review): the forwarding headers are taken as-is from the request; a
-- client can forge them unless a trusted proxy overwrites them — confirm
-- against the deployment.
-- @param r request object (reads r.headers_in and r.useragent_ip)
-- @return exactly one value: the IP string
local function get_client_ip(r)
    local headers = r.headers_in

    -- X-Forwarded-For (first IP in chain)
    local xff = headers["X-Forwarded-For"]
    if xff then
        local first = xff:match("^([^,]+)")
        if first then
            -- A single capture is used for trimming: string.gsub also
            -- returns a replacement count, which would otherwise leak out
            -- of this function as a second return value.
            local trimmed = first:match("^%s*(.-)%s*$")
            -- A blank first entry is not a usable address; fall through.
            if trimmed ~= "" then return trimmed end
        end
    end

    -- X-Real-IP
    local xri = headers["X-Real-IP"]
    if xri then return xri end

    -- CF-Connecting-IP (Cloudflare)
    local cfip = headers["CF-Connecting-IP"]
    if cfip then return cfip end

    -- True-Client-IP (Akamai, Cloudflare Enterprise)
    local tcip = headers["True-Client-IP"]
    if tcip then return tcip end

    -- Fall back to direct connection IP
    return r.useragent_ip
end

--- Read the value of the "_centinel" cookie from the request.
-- The Cookie header is split on ";" and each pair's name is compared
-- exactly, so a cookie whose name merely ends in "_centinel" (for example
-- "my_centinel") is not mistaken for it.
-- @param r request object (reads r.headers_in["Cookie"])
-- @return the cookie value, or nil when the header or cookie is absent or
--         the cookie value is empty
local function extract_centinel_cookie(r)
    local cookie_header = r.headers_in["Cookie"]
    if not cookie_header then return nil end

    for pair in cookie_header:gmatch("[^;]+") do
        -- name: everything before the first "=", surrounding spaces removed;
        -- value: the rest of the pair, which may itself contain "=".
        local name, value = pair:match("^%s*([^=]-)%s*=(.*)$")
        if name == "_centinel" and value ~= "" then
            return value
        end
    end

    return nil
end

--- Serialise a cookie table into Set-Cookie header text.
-- Produces "name=value", followed by "; Path=..." and "; Domain=..." for
-- whichever of those fields is present and not the empty string. No other
-- cookie fields are read.
-- @param cookie table with `name` and `value`, optionally `path`, `domain`
-- @return header value string
local function cookie_obj_to_string(cookie)
    local text = cookie.name .. "=" .. cookie.value

    local path = cookie.path
    if path and path ~= "" then
        text = text .. "; " .. "Path=" .. path
    end

    local domain = cookie.domain
    if domain and domain ~= "" then
        text = text .. "; " .. "Domain=" .. domain
    end

    return text
end

--- Copy validator-supplied cookies onto the response.
-- Table entries are serialised with cookie_obj_to_string(); any other entry
-- is converted with tostring(). Does nothing unless `cookies` is a
-- non-empty array.
-- NOTE(review): every entry is assigned to the same "Set-Cookie" key. If
-- assignment on r.err_headers_out replaces rather than appends, only the
-- last cookie of the list reaches the client — confirm against mod_lua.
-- @param r       request object (writes r.err_headers_out)
-- @param cookies array of cookie tables and/or preformatted strings
local function set_response_cookies(r, cookies)
    if type(cookies) ~= "table" or #cookies == 0 then
        return
    end

    -- err_headers_out is used so the header also accompanies error and
    -- redirect responses.
    local out = r.err_headers_out
    for _, entry in ipairs(cookies) do
        local header_value
        if type(entry) ~= "table" then
            header_value = tostring(entry)
        else
            header_value = cookie_obj_to_string(entry)
        end
        out["Set-Cookie"] = header_value
    end
end

--- Rebuild the absolute URL of the current request.
-- Scheme comes from r.is_https, host from r.hostname (default "localhost"),
-- and the path/query from r.unparsed_uri, then r.uri, then "/". No port is
-- added.
-- @param r request object
-- @return URL string such as "https://example.com/path?x=1"
local function get_full_url(r)
    local scheme
    if r.is_https then
        scheme = "https"
    else
        scheme = "http"
    end

    local host = r.hostname or "localhost"
    local target = r.unparsed_uri or r.uri or "/"

    return scheme .. "://" .. host .. target
end

--- Collect a fixed set of request headers for the validator payload.
-- Header names are probed one by one instead of iterating r.headers_in,
-- because that object is userdata in mod_lua and cannot be walked with
-- pairs(). Headers outside the list are not forwarded; note that the list
-- includes Cookie and Authorization.
-- @param r request object (reads r.headers_in)
-- @return table mapping lower-cased header name to its value, containing
--         only the headers present on the request
local function get_request_headers(r)
    local wanted = {
        "Host", "User-Agent", "Accept", "Accept-Language",
        "Accept-Encoding", "Connection", "Cookie", "Referer",
        "X-Forwarded-For", "X-Real-IP", "CF-Connecting-IP",
        "True-Client-IP", "X-Requested-With", "Origin",
        "Content-Type", "Content-Length", "Authorization",
        "Cache-Control", "Pragma", "If-None-Match",
        "If-Modified-Since", "Sec-Fetch-Dest", "Sec-Fetch-Mode",
        "Sec-Fetch-Site", "Sec-Ch-Ua", "Sec-Ch-Ua-Mobile",
        "Sec-Ch-Ua-Platform", "DNT", "Upgrade-Insecure-Requests"
    }

    local incoming = r.headers_in
    local collected = {}
    for i = 1, #wanted do
        local header_name = wanted[i]
        local header_value = incoming[header_name]
        if header_value then
            collected[string.lower(header_name)] = header_value
        end
    end
    return collected
end

--------------------------------------------------------------------------------
-- Backoff Management (Circuit Breaker)
--------------------------------------------------------------------------------

--- Current time in milliseconds.
-- Takes the value of luasocket's socket.gettime() (the "socket" module is
-- required on each call) and multiplies it by 1000.
-- @return number of milliseconds
local function get_time_ms()
    local seconds = require("socket").gettime()
    return seconds * 1000
end

--- Report whether the circuit breaker is currently open.
-- True while a backoff deadline is set (greater than 0) and the current
-- time is still before it; the remaining time is logged at debug level.
-- @param r request object, used for logging only
-- @return boolean
local function is_in_backoff(r)
    local now = get_time_ms()
    local deadline = backoff_state.backoff_until

    -- No deadline set, or the deadline has already passed.
    if deadline <= 0 or now >= deadline then
        return false
    end

    log(r, "debug", "In backoff period, " .. math.floor(deadline - now) .. "ms remaining")
    return true
end

--- Register a validator failure and open the circuit breaker.
-- Increments the failure counter and schedules the next allowed validator
-- call: the delay starts at INITIAL_BACKOFF_MS, is multiplied by
-- BACKOFF_MULTIPLIER for each further consecutive failure, and is capped at
-- MAX_BACKOFF_MS.
-- @param r request object, used for logging only
local function record_failure(r)
    local count = backoff_state.failures + 1
    backoff_state.failures = count

    local delay_ms = INITIAL_BACKOFF_MS * BACKOFF_MULTIPLIER ^ (count - 1)
    if delay_ms > MAX_BACKOFF_MS then
        delay_ms = MAX_BACKOFF_MS
    end

    backoff_state.backoff_until = get_time_ms() + delay_ms

    local message = string.format(
        "Validator failure #%d, backing off for %dms",
        count, delay_ms
    )
    log(r, "warn", message)
end

--- Close the circuit breaker after a usable validator answer.
-- Resets the failure counter and the backoff deadline to 0; a debug line is
-- logged only when there were failures to clear.
-- @param r request object, used for logging only
local function record_success(r)
    local had_failures = backoff_state.failures > 0
    if had_failures then
        log(r, "debug", "Validator success, reset backoff")
    end

    backoff_state.failures, backoff_state.backoff_until = 0, 0
end

--------------------------------------------------------------------------------
-- Path Protection Logic
--------------------------------------------------------------------------------

--- Decide whether a request path must be validated.
-- The URI is lower-cased and tested against Lua patterns:
--   1. any match in config.unprotected_paths  -> not protected
--   2. config.protected_paths non-empty       -> protected only on a match
--   3. otherwise                              -> protected
-- A nil URI is treated as protected.
-- @param uri request path, or nil
-- @return boolean
local function should_protect_path(uri)
    if not uri then
        return true
    end

    local path = uri:lower()

    -- True when `path` matches at least one pattern of the list.
    local function matches_any(patterns)
        for _, candidate in ipairs(patterns) do
            if path:match(candidate) then
                return true
            end
        end
        return false
    end

    -- Exclusions win over everything else.
    if matches_any(config.unprotected_paths) then
        return false
    end

    -- Without an inclusion list, every remaining path is protected.
    local inclusions = config.protected_paths
    if not inclusions or #inclusions == 0 then
        return true
    end

    return matches_any(inclusions)
end

--------------------------------------------------------------------------------
-- Persistent HTTP/2 Connection via libcurl
--------------------------------------------------------------------------------

--- Return the shared lcurl easy handle, creating and configuring it on
-- first use.
-- The new handle is pointed at config.validator_url, asked for HTTP/2 over
-- TLS when the lcurl build exposes that constant, given the configured
-- timeouts, optionally stripped of TLS verification, set to use TCP
-- keep-alive and to send POST requests, and then cached in `curl_handle`.
-- @param r request object, used for logging only
-- @return lcurl easy handle
local function get_curl_handle(r)
    if curl_handle then
        return curl_handle
    end

    log(r, "debug", "Creating new curl handle (HTTP/2)")

    local handle = curl.easy()

    -- Target URL; it stays the same for every later request on this handle.
    handle:setopt_url(config.validator_url)

    -- Prefer HTTP/2 over TLS. The option is skipped when the constant is
    -- not available, which leaves libcurl's default protocol selection.
    local http2_over_tls = curl.HTTP_VERSION_2TLS
    if http2_over_tls then
        handle:setopt(curl.OPT_HTTP_VERSION, http2_over_tls)
    end

    -- Connect and total-transfer timeouts
    handle:setopt_connecttimeout_ms(config.connect_timeout_ms)
    handle:setopt_timeout_ms(config.timeout_ms)

    -- Certificate and hostname checks are only turned off on request.
    local verify_tls = config.ssl_verify
    if not verify_tls then
        handle:setopt(curl.OPT_SSL_VERIFYPEER, 0)
        handle:setopt(curl.OPT_SSL_VERIFYHOST, 0)
    end

    -- TCP keep-alive probes for the connection held between requests
    handle:setopt(curl.OPT_TCP_KEEPALIVE, 1)

    -- All validator calls are POSTs
    handle:setopt_post(true)

    curl_handle = handle
    return handle
end

--- Close and forget the shared curl handle, if there is one.
-- Errors raised while closing are ignored; the cached reference is cleared
-- either way so the next get_curl_handle() call builds a fresh handle.
local function destroy_curl_handle()
    if not curl_handle then
        return
    end

    pcall(function()
        curl_handle:close()
    end)
    curl_handle = nil
end

-- Values reported by libcurl's CURLINFO_HTTP_VERSION (CURL_HTTP_VERSION_*).
local HTTP_VERSION_NAMES = {
    [1] = "HTTP/1.0",
    [2] = "HTTP/1.1",
    [3] = "HTTP/2",
    [30] = "HTTP/3",
}

--- POST the request description to the validator and return its raw answer.
-- The payload holds the request URL, method, client IP, "_centinel" cookie,
-- referrer and a selection of request headers, encoded as JSON.
--
-- Up to two attempts are made. After a failed attempt the shared curl handle
-- is destroyed, so the second attempt runs on a fresh handle (a stale
-- connection is the expected cause of a first failure).
--
-- Each attempt runs under pcall: the `lcurl` module loaded by this file
-- reports failures by raising Lua errors, so without it a transport error
-- would abort the request hook instead of reaching the retry and the
-- caller's fail-open/fail-closed handling. A nil, err return from perform()
-- is handled as well.
-- @param r request object
-- @return { status = <HTTP code>, body = <string> }, nil on success
-- @return nil, error message when encoding or both attempts fail
local function call_validator(r)
    local request_body = {
        url      = get_full_url(r),
        method   = r.method,
        ip       = get_client_ip(r),
        cookie   = extract_centinel_cookie(r),
        referrer = r.headers_in["Referer"] or "",
        headers  = get_request_headers(r),
    }

    local body_json, encode_err = cjson.encode(request_body)
    if not body_json then
        log(r, "error", "Failed to encode request body: " .. tostring(encode_err))
        return nil, "JSON encode error"
    end

    log(r, "debug", "Request body: " .. body_json)
    log(r, "debug", "Calling validator: " .. config.validator_url)

    local request_headers = {
        "Content-Type: application/json",
        "x-api-key: " .. (config.secret_key or ""),
        "x-origin-module: apache",
        "x-origin-version: " .. _M._VERSION,
    }

    -- One request/response exchange on the shared handle. May raise.
    -- Returns response, nil, http_version on success or nil, err on failure.
    local function perform_once()
        local easy = get_curl_handle(r)

        -- Request-specific options (the URL stays the same across calls)
        easy:setopt_postfields(body_json)
        easy:setopt_httpheader(request_headers)

        -- Collect the response body chunk by chunk
        local chunks = {}
        easy:setopt_writefunction(function(data)
            chunks[#chunks + 1] = data
            return #data
        end)

        local ok, err = easy:perform()
        if not ok then
            return nil, err
        end

        local response = {
            status = easy:getinfo_response_code(),
            body = table.concat(chunks),
        }

        -- The negotiated HTTP version is only needed for the debug log, so
        -- a failure to read it must not fail the request.
        local http_version
        if config.debug and curl.INFO_HTTP_VERSION then
            local got, value = pcall(easy.getinfo, easy, curl.INFO_HTTP_VERSION)
            if got then
                http_version = value
            end
        end

        return response, nil, http_version
    end

    for attempt = 1, 2 do
        local completed, res, err, http_version = pcall(perform_once)

        if completed and res then
            if config.debug then
                local proto_name = HTTP_VERSION_NAMES[http_version]
                if proto_name then
                    log(r, "debug",
                        "Validator response: " .. proto_name .. " " .. tostring(res.status) ..
                        " (" .. #res.body .. " bytes)")
                else
                    log(r, "debug",
                        "Validator response status: " .. tostring(res.status) ..
                        " (" .. #res.body .. " bytes)")
                end
            end

            log(r, "debug", "Validator response body: " .. res.body)
            return res, nil
        end

        -- When pcall caught an error, the error value is in `res`.
        local err_msg
        if completed then
            err_msg = tostring(err)
        else
            err_msg = tostring(res)
        end

        log(r, "warn", "Validator request failed (attempt " .. attempt .. "): " .. err_msg)
        destroy_curl_handle()

        if attempt == 2 then
            return nil, err_msg
        end
    end
end

--------------------------------------------------------------------------------
-- Response Handling
--------------------------------------------------------------------------------

--- Send an HTML page to the client and end request processing.
-- Applies the cookies, sets status, content type and a no-cache
-- Cache-Control header, then writes the page body.
-- @param r       request object
-- @param status  HTTP status code for the response
-- @param html    page body
-- @param cookies cookie list passed on to set_response_cookies()
-- @return apache2.DONE
local function render_html_response(r, status, html, cookies)
    set_response_cookies(r, cookies)

    -- Response metadata is filled in before the body is written.
    r.status = status
    r.content_type = "text/html; charset=utf-8"

    local response_headers = r.headers_out
    response_headers["Cache-Control"] = "no-store, no-cache, must-revalidate"

    r:write(html)

    return apache2.DONE
end

--- Interpret the validator's answer and act on its decision.
--
-- Decisions:
--   "allow", "redirect", "not_matched"  render the supplied HTML when there
--                                       is any, otherwise let the request
--                                       continue
--   "block"                             always render a page (the supplied
--                                       HTML or a built-in "Access Denied"
--                                       page), default status 403
-- Cookies from the answer are applied for every decision except
-- "not_matched", where they are only applied if a page is rendered.
-- Any recognised decision resets the backoff state.
--
-- The decoded payload is checked to be a JSON object, and `request_id`,
-- `response_html` and `status_code` are type-checked before use, so an
-- unexpected payload ends in the caller's failure handling rather than in a
-- Lua error.
--
-- @param r   request object
-- @param res table { status = <HTTP code>, body = <string> }, or nil
-- @return true, nil              decision handled, request may continue
-- @return true, nil, <code>      decision handled, a response was written;
--                                <code> is what the hook must return
-- @return false, <reason string> the answer could not be used
local function handle_validator_response(r, res)
    if not res then
        return false, "No response"
    end

    if res.status ~= 200 then
        log(r, "warn", "Validator returned non-200 status: " .. tostring(res.status) ..
            " body: " .. (res.body or "nil"))
        return false, "Non-200 status"
    end

    local data, decode_err = cjson.decode(res.body)
    if type(data) ~= "table" then
        -- Covers both a decode failure and valid JSON that is not an object
        -- (a bare number, string or boolean).
        log(r, "error", "Failed to decode validator response: " ..
            tostring(decode_err or "response is not a JSON object"))
        return false, "JSON decode error"
    end

    local decision = data.decision
    local request_id = data.request_id or "unknown"
    log(r, "info", "Validator decision: " .. tostring(decision) ..
        " (request_id: " .. tostring(request_id) .. ")")

    local applies_cookies = decision == "allow" or decision == "block" or decision == "redirect"
    if not applies_cookies and decision ~= "not_matched" then
        log(r, "warn", "Unknown decision: " .. tostring(decision))
        return false, "Unknown decision"
    end

    if applies_cookies then
        set_response_cookies(r, data.cookies)
    end
    record_success(r)

    -- Only a string can be base64 text; anything else counts as "no page".
    local html
    if type(data.response_html) == "string" then
        html = base64_decode(data.response_html)
    end

    local default_status = 200
    if decision == "block" then
        default_status = 403
        html = html or
            "<html><body><h1>Access Denied</h1><p>Your request has been blocked by bot protection.</p></body></html>"
    end

    if html then
        -- tonumber() yields nil for a missing or non-numeric status_code,
        -- which selects the default for this decision.
        local status = tonumber(data.status_code) or default_status
        return true, nil, render_html_response(r, status, html, data.cookies)
    end

    return true, nil
end

--------------------------------------------------------------------------------
-- Main Access Handler (called by LuaHookAccessChecker)
--------------------------------------------------------------------------------

-- Answer with a minimal 503 page; used when the validator is unusable and
-- config.fail_open is false.
local function send_service_unavailable(r)
    r.status = 503
    r.content_type = "text/html"
    r:write("<html><body><h1>Service Temporarily Unavailable</h1></body></html>")
    return apache2.DONE
end

--- Access-checker entry point.
-- Flow:
--   1. no secret key configured  -> allow
--   2. circuit breaker open      -> allow
--   3. path not protected        -> allow
--   4. call the validator; if the call fails or its answer cannot be used,
--      record a failure and either allow (config.fail_open) or answer 503
--   5. otherwise return what the decision handler produced, or allow
-- Steps 1 and 2 allow the request regardless of config.fail_open.
-- @param r request object
-- @return apache2.OK to let the request continue, or apache2.DONE when a
--         response has already been written
function _M.access_handler(r)
    local secret = config.secret_key
    if not secret or secret == "" then
        log(r, "error", "CENTINEL_SECRET_KEY not configured")
        return apache2.OK
    end

    if is_in_backoff(r) then
        log(r, "debug", "Skipping validation (in backoff)")
        return apache2.OK
    end

    local uri = r.uri
    if not should_protect_path(uri) then
        log(r, "debug", "Path not protected: " .. uri)
        return apache2.OK
    end

    log(r, "debug", "Validating request: " .. uri)

    -- Ask the validator; a nil result means the call itself failed.
    local res = call_validator(r)
    if not res then
        record_failure(r)
        if not config.fail_open then
            log(r, "error", "Validator unavailable, blocking request (fail-closed)")
            return send_service_unavailable(r)
        end
        log(r, "warn", "Validator unavailable, allowing request (fail-open)")
        return apache2.OK
    end

    -- Act on the validator's answer.
    local handled, handle_err, result = handle_validator_response(r, res)
    if not handled then
        record_failure(r)
        if not config.fail_open then
            return send_service_unavailable(r)
        end
        log(r, "warn", "Failed to handle validator response: " .. tostring(handle_err) .. ", allowing request")
        return apache2.OK
    end

    -- A non-nil result means the decision handler already wrote a response;
    -- otherwise the request continues to the backend.
    return result or apache2.OK
end

--------------------------------------------------------------------------------
-- Initialization
--------------------------------------------------------------------------------

--- Configure the module.
-- Every key of `user_config` is copied over the defaults (including the
-- `env` key itself). Afterwards, values from `user_config.env` take
-- precedence: a non-empty CENTINEL_SECRET_KEY and CENTINEL_VALIDATOR_URL
-- replace the corresponding settings, and CENTINEL_DEBUG equal to "true" or
-- "1" switches debug logging on (it never switches it off).
-- The `env` table exists because the caller is expected to pass environment
-- values in explicitly; this module does not read the process environment.
-- Any existing curl handle is discarded so the next request uses the new
-- settings.
-- @param user_config optional table of overrides
-- @return the module table, so the call can be chained
function _M.init(user_config)
    if user_config then
        for key, value in pairs(user_config) do
            config[key] = value
        end

        local env = user_config.env
        if env then
            local secret = env.CENTINEL_SECRET_KEY
            if secret and secret ~= "" then
                config.secret_key = secret
            end

            local url = env.CENTINEL_VALIDATOR_URL
            if url and url ~= "" then
                config.validator_url = url
            end

            local debug_flag = env.CENTINEL_DEBUG
            if debug_flag == "true" or debug_flag == "1" then
                config.debug = true
            end
        end
    end

    destroy_curl_handle()

    return _M
end

--------------------------------------------------------------------------------
-- Configuration Helpers
--------------------------------------------------------------------------------

--- Replace the inclusion pattern list.
-- The call is ignored unless `paths` is a table. The table is stored as
-- given, not copied.
-- @param paths array of Lua patterns; an empty array protects every path
--              that is not excluded
function _M.set_protected_paths(paths)
    if type(paths) ~= "table" then
        return
    end
    config.protected_paths = paths
end

--- Replace the exclusion pattern list, dropping the built-in defaults.
-- The call is ignored unless `paths` is a table. The table is stored as
-- given, not copied.
-- @param paths array of Lua patterns for paths that skip validation
function _M.set_unprotected_paths(paths)
    if type(paths) ~= "table" then
        return
    end
    config.unprotected_paths = paths
end

--- Append one pattern to the exclusion list.
-- Like the list setters, the call is ignored when the argument has the
-- wrong type: a non-string entry in the list would make the pattern match
-- in should_protect_path() raise an error on every request.
-- @param pattern Lua pattern, matched against the lower-cased request URI
function _M.add_unprotected_pattern(pattern)
    if type(pattern) ~= "string" then
        return
    end
    table.insert(config.unprotected_paths, pattern)
end

--- Return the live configuration table.
-- This is the same table the module reads, not a copy, so changes made by
-- the caller take effect immediately.
_M.get_config = function()
    return config
end

--- Return the module version string.
_M.get_version = function()
    return _M._VERSION
end

-- Exposed so tests can exercise the path-matching rules directly.
_M.should_protect_path = should_protect_path

return _M
