//
// Syd: rock-solid application kernel
// src/parsers/proc.rs: /proc nom parsers
//
// Copyright (c) 2024, 2025, 2026 Ali Polatel <alip@chesswob.org>
// This file is based in part upon procinfo-rs crate which is:
//   Copyright (c) 2015 The Rust Project Developers
//   SPDX-License-Identifier: MIT
//
// SPDX-License-Identifier: GPL-3.0

// SAFETY: This module has been liberated from unsafe code!
#![forbid(unsafe_code)]

//! /proc Parsers and utility functions.

use btoi::{btoi, btoi_radix};
use libc::mode_t;
use memchr::memrchr;
use nix::{sys::stat::Mode, unistd::Pid};
use nom::{
    branch::alt,
    bytes::complete::{tag, take, take_until, take_while1},
    combinator::{map, map_res, peek},
    error::{Error, ErrorKind},
    multi::fold_many0,
    sequence::{delimited, preceded},
    IResult, Parser,
};
use procfs_core::process::LimitValue;

use crate::{
    path::XPath,
    proc::{Stat, Statm, Status},
    sigset::{sigset_t, SydSigSet},
};

/// Enum to represent different lines in the status file.
///
/// Each variant carries the already-decoded value of one line that
/// `parse_status` recognises; every other line is reported as `Skip`.
#[derive(Copy, Clone, Eq, PartialEq)]
enum StatusLine {
    /// Value of the `Umask:` line.
    Umask(Mode),
    /// Value of the `Tgid:` line.
    Pid(Pid),
    /// Value of the `SigPnd:` line.
    SigPendingThread(SydSigSet),
    /// Value of the `ShdPnd:` line.
    SigPendingProcess(SydSigSet),
    /// Value of the `SigBlk:` line.
    SigBlocked(SydSigSet),
    /// Value of the `SigIgn:` line.
    SigIgnored(SydSigSet),
    /// Value of the `SigCgt:` line.
    SigCaught(SydSigSet),
    /// Any line that none of the field parsers matched.
    Skip,
}

/// Parses proc_pid_status(5), extracting the relevant fields.
pub(crate) fn parse_status(input: &[u8]) -> IResult<&[u8], Status> {
    fold_many0(
        alt((
            map(parse_umask, StatusLine::Umask),
            map(parse_tgid, StatusLine::Pid),
            map(parse_sig_pending_thread, StatusLine::SigPendingThread),
            map(parse_sig_pending_process, StatusLine::SigPendingProcess),
            map(parse_sig_blocked, StatusLine::SigBlocked),
            map(parse_sig_ignored, StatusLine::SigIgnored),
            map(parse_sig_caught, StatusLine::SigCaught),
            map(
                delimited(take_until(&b"\n"[..]), tag(&b"\n"[..]), tag(&b""[..])),
                |_| StatusLine::Skip,
            ),
        )),
        Status::default,
        |mut acc, line| {
            match line {
                StatusLine::Umask(umask) => acc.umask = umask,
                StatusLine::Pid(pid) => acc.pid = pid,
                StatusLine::SigPendingThread(set) => acc.sig_pending_thread = set,
                StatusLine::SigPendingProcess(set) => acc.sig_pending_process = set,
                StatusLine::SigBlocked(set) => acc.sig_blocked = set,
                StatusLine::SigIgnored(set) => acc.sig_ignored = set,
                StatusLine::SigCaught(set) => acc.sig_caught = set,
                StatusLine::Skip => {}
            }
            acc
        },
    )
    .parse(input)
}

/// Parses proc_pid_stat(5) and returns the subset of fields kept in [`Stat`].
///
/// The line is walked left to right and unneeded fields are skipped by
/// count.  With 1-based field numbers, `tty_nr` is field 7, `num_threads`
/// is field 20, `startstack` is field 28 and `start_brk` is field 47.
/// Whatever follows field 47 is returned as the unparsed rest.
pub(crate) fn parse_stat(input: &[u8]) -> IResult<&[u8], Stat> {
    // Field 1 (pid) and field 2 (comm) must be well-formed but are not kept.
    let (rest, _) = parse_pid(input)?;
    let (rest, _) = tag(" ")(rest)?;
    let (rest, _) = parse_comm(rest)?;
    let (rest, _) = tag(" ")(rest)?;

    // Fields 3-6 are skipped, field 7 is tty_nr.
    let (rest, _) = skip_fields(4)(rest)?;
    let (rest, tty_nr) = parse_tty_nr(rest)?;
    let (rest, _) = tag(" ")(rest)?;

    // Fields 8-19 are skipped, field 20 is num_threads.
    let (rest, _) = skip_fields(12)(rest)?;
    let (rest, num_threads) = parse_num_threads(rest)?;
    let (rest, _) = tag(" ")(rest)?;

    // Fields 21-27 are skipped, field 28 is startstack.
    let (rest, _) = skip_fields(7)(rest)?;
    let (rest, startstack) = parse_startstack(rest)?;
    let (rest, _) = tag(" ")(rest)?;

    // Fields 29-46 are skipped, field 47 is start_brk.
    let (rest, _) = skip_fields(18)(rest)?;
    let (rest, startbrk) = parse_startbrk(rest)?;

    let stat = Stat {
        tty_nr,
        num_threads,
        startstack,
        startbrk,
    };
    Ok((rest, stat))
}

/// Parses proc_pid_statm(5), extracting only the `size` field.
pub(crate) fn parse_statm(input: &[u8]) -> IResult<&[u8], Statm> {
    let (input, size) = parse_u64_decimal(input)?;
    Ok((input, Statm { size }))
}

/// Parses only the Tgid from proc_pid_status(5), skipping everything else.
pub(crate) fn parse_status_tgid(input: &[u8]) -> IResult<&[u8], Pid> {
    preceded(take_until(&b"Tgid:\t"[..]), parse_tgid).parse(input)
}

/// Parses only the Umask from proc_pid_status(5), skipping everything else.
pub(crate) fn parse_status_umask(input: &[u8]) -> IResult<&[u8], Mode> {
    preceded(take_until(&b"Umask:\t"[..]), parse_umask).parse(input)
}

/// Parses only the Pid from /proc/thread-self/fdinfo/<pidfd>, skipping everything else.
pub(crate) fn parse_pidfd_info_pid(input: &[u8]) -> IResult<&[u8], Pid> {
    preceded(
        take_until(&b"Pid:\t"[..]),
        delimited(tag(&b"Pid:\t"[..]), parse_pid, tag(&b"\n"[..])),
    )
    .parse(input)
}

/// Builds a parser that discards `n` fields, each followed by one space.
///
/// A field is a non-empty run of bytes other than `b' '`.  The returned
/// parser fails if a field is empty or is not followed by a space, and
/// succeeds without consuming anything when `n` is zero.
fn skip_fields<'a>(n: usize) -> impl Fn(&'a [u8]) -> IResult<&'a [u8], ()> {
    move |input: &[u8]| {
        let mut rest = input;
        let mut remaining = n;
        while remaining > 0 {
            // Field body first, then its single-space separator.
            let (after_field, _) = take_while1(|byte| byte != b' ')(rest)?;
            let (after_space, _) = tag(" ")(after_field)?;
            rest = after_space;
            remaining -= 1;
        }
        Ok((rest, ()))
    }
}

/// Parses the "comm" field (executable name) from proc_pid_stat(5).
fn parse_comm(input: &[u8]) -> IResult<&[u8], &XPath> {
    const TASK_COMM_LEN: usize = 16;
    let (after_open, _) = tag(&b"("[..]).parse(input)?;

    let window_len = after_open.len().min(TASK_COMM_LEN);
    let (_, window) = peek(take(window_len)).parse(after_open)?;

    let end_index = if let Some(end_index) = memrchr(b')', window) {
        end_index
    } else {
        return Err(nom::Err::Error(Error::new(after_open, ErrorKind::Tag)));
    };

    let (after_comm, comm) = take(end_index)(after_open)?;
    let (rest, _) = tag(&b")"[..])(after_comm)?;

    Ok((rest, XPath::from_bytes(comm)))
}

/// Extract the soft limit for "Max open files" from proc_pid_limits(5).
pub(crate) fn parse_max_open_files(input: &[u8]) -> IResult<&[u8], LimitValue> {
    preceded(
        // skip ahead to the label.
        take_until("Max open files"),
        // tag and skip the label plus following whitespace.
        // parse either number or "unlimited".
        preceded(
            tag("Max open files"),
            preceded(nom::character::complete::space1, parse_limit_value),
        ),
    )
    .parse(input)
}

/// Parses one limit value: the literal `unlimited` or a decimal number.
///
/// The literal is tried first; a number is only attempted if it does not
/// match.
fn parse_limit_value(input: &[u8]) -> IResult<&[u8], LimitValue> {
    let unlimited = map(tag("unlimited"), |_| LimitValue::Unlimited);
    let numeric = map(parse_u64_decimal, LimitValue::Value);
    alt((unlimited, numeric)).parse(input)
}

/// Parses the "tty_nr" field from proc_pid_stat(5).
///
/// Thin wrapper around `parse_i32_decimal`: only a run of ASCII digits is
/// accepted, so a value with a leading sign is rejected.
fn parse_tty_nr(input: &[u8]) -> IResult<&[u8], i32> {
    parse_i32_decimal(input)
}

/// Parses the "num_threads" field from proc_pid_stat(5).
///
/// Thin wrapper around `parse_u64_decimal`; exists so that `parse_stat`
/// reads field by field.
fn parse_num_threads(input: &[u8]) -> IResult<&[u8], u64> {
    parse_u64_decimal(input)
}

/// Parses the "startstack" field from proc_pid_stat(5).
///
/// Thin wrapper around `parse_u64_decimal`; the value is read as a decimal
/// number.
fn parse_startstack(input: &[u8]) -> IResult<&[u8], u64> {
    parse_u64_decimal(input)
}

/// Parses the "start_brk" field from proc_pid_stat(5).
///
/// Thin wrapper around `parse_u64_decimal`; the value is read as a decimal
/// number.
fn parse_startbrk(input: &[u8]) -> IResult<&[u8], u64> {
    parse_u64_decimal(input)
}

/// Parses the "Umask" field from proc_pid_status(5).
fn parse_umask(input: &[u8]) -> IResult<&[u8], Mode> {
    delimited(tag(&b"Umask:\t"[..]), parse_mode, tag(&b"\n"[..])).parse(input)
}

/// Parses the "SigPnd" field from proc_pid_status(5).
fn parse_sig_pending_thread(input: &[u8]) -> IResult<&[u8], SydSigSet> {
    delimited(
        tag(&b"SigPnd:\t"[..]),
        map_res(
            take_while1(|c: u8| c.is_ascii_hexdigit()),
            |bytes: &[u8]| {
                btoi_radix::<sigset_t>(bytes, 16)
                    .map(SydSigSet::new)
                    .map_err(|_| Error::new(input, ErrorKind::Digit))
            },
        ),
        tag(&b"\n"[..]),
    )
    .parse(input)
}

/// Parses the "ShdPnd" field from proc_pid_status(5).
fn parse_sig_pending_process(input: &[u8]) -> IResult<&[u8], SydSigSet> {
    delimited(
        tag(&b"ShdPnd:\t"[..]),
        map_res(
            take_while1(|c: u8| c.is_ascii_hexdigit()),
            |bytes: &[u8]| {
                btoi_radix::<sigset_t>(bytes, 16)
                    .map(SydSigSet::new)
                    .map_err(|_| Error::new(input, ErrorKind::Digit))
            },
        ),
        tag(&b"\n"[..]),
    )
    .parse(input)
}

/// Parses the "SigBlk" field from proc_pid_status(5).
fn parse_sig_blocked(input: &[u8]) -> IResult<&[u8], SydSigSet> {
    delimited(
        tag(&b"SigBlk:\t"[..]),
        map_res(
            take_while1(|c: u8| c.is_ascii_hexdigit()),
            |bytes: &[u8]| {
                btoi_radix::<sigset_t>(bytes, 16)
                    .map(SydSigSet::new)
                    .map_err(|_| Error::new(input, ErrorKind::Digit))
            },
        ),
        tag(&b"\n"[..]),
    )
    .parse(input)
}

/// Parses the "SigIgn" field from proc_pid_status(5).
fn parse_sig_ignored(input: &[u8]) -> IResult<&[u8], SydSigSet> {
    delimited(
        tag(&b"SigIgn:\t"[..]),
        map_res(
            take_while1(|c: u8| c.is_ascii_hexdigit()),
            |bytes: &[u8]| {
                btoi_radix::<sigset_t>(bytes, 16)
                    .map(SydSigSet::new)
                    .map_err(|_| Error::new(input, ErrorKind::Digit))
            },
        ),
        tag(&b"\n"[..]),
    )
    .parse(input)
}

/// Parses the "SigCgt" field from proc_pid_status(5).
fn parse_sig_caught(input: &[u8]) -> IResult<&[u8], SydSigSet> {
    delimited(
        tag(&b"SigCgt:\t"[..]),
        map_res(
            take_while1(|c: u8| c.is_ascii_hexdigit()),
            |bytes: &[u8]| {
                btoi_radix::<sigset_t>(bytes, 16)
                    .map(SydSigSet::new)
                    .map_err(|_| Error::new(input, ErrorKind::Digit))
            },
        ),
        tag(&b"\n"[..]),
    )
    .parse(input)
}

/// Parses the "Tgid" field from proc_pid_status(5).
fn parse_tgid(input: &[u8]) -> IResult<&[u8], Pid> {
    delimited(tag(&b"Tgid:\t"[..]), parse_pid, tag(&b"\n"[..])).parse(input)
}

/// Parses a `Pid`.
fn parse_pid(input: &[u8]) -> IResult<&[u8], Pid> {
    map(parse_i32_decimal, Pid::from_raw).parse(input)
}

/// Parses a non-empty run of octal digits into a [`Mode`].
///
/// All bits of the decoded value are retained, including any that `Mode`
/// does not name.  Fails if there is no leading octal digit or if the value
/// does not fit into `mode_t`.
fn parse_mode(input: &[u8]) -> IResult<&[u8], Mode> {
    let is_octal_digit = |byte: u8| matches!(byte, b'0'..=b'7');
    let to_mode = |digits: &[u8]| {
        btoi_radix::<mode_t>(digits, 8)
            .map(Mode::from_bits_retain)
            .map_err(|_| Error::new(input, ErrorKind::Digit))
    };

    map_res(take_while1(is_octal_digit), to_mode).parse(input)
}

/// Parses a `u64`.
fn parse_u64_decimal(input: &[u8]) -> IResult<&[u8], u64> {
    map_res(take_while1(|c: u8| c.is_ascii_digit()), |bytes: &[u8]| {
        btoi::<u64>(bytes).map_err(|_| Error::new(input, ErrorKind::Digit))
    })
    .parse(input)
}

/// Parses a `i32`.
fn parse_i32_decimal(input: &[u8]) -> IResult<&[u8], i32> {
    map_res(take_while1(|c: u8| c.is_ascii_digit()), |bytes: &[u8]| {
        btoi::<i32>(bytes).map_err(|_| Error::new(input, ErrorKind::Digit))
    })
    .parse(input)
}

#[cfg(test)]
mod tests {
    use super::*;

    /*
     * parse_comm test cases
     */
    // (input, should_parse, expected_comm, expected_rest_prefix)
    //
    // For cases that must fail, the last two members are unused.
    type Case = (&'static [u8], bool, &'static [u8], &'static [u8]);

    static CASES: &[Case] = &[
        // simple
        (b"(bash) R 1 2 3 ", true, b"bash", b" R "),
        (b"(init) S 1 2 3 ", true, b"init", b" S "),
        (b"(a) R 0 0 0 ", true, b"a", b" R "),
        (b"() R 1 2 3 ", true, b"", b" R "),
        (b"( ) R 1 2 3 ", true, b" ", b" R "),
        // spaces
        (b"(my app) R 1 2 3 ", true, b"my app", b" R "),
        (b"( a  b  ) R 1 2 3 ", true, b" a  b  ", b" R "),
        (b"(tab\tname) S 1 2 3 ", true, b"tab\tname", b" S "),
        // embedded ')' cases
        (b"(lol) hey) R 1 2 3 ", true, b"lol) hey", b" R "),
        (b"(a)b)c) R 1 2 3 ", true, b"a)b)c", b" R "),
        (b"((())) ) R 1 2 3 ", true, b"(())) ", b" R "),
        (b"(()))))  ) R 1 2 3 ", true, b"()))))  ", b" R "),
        (b"(par)en)ted) R 1 2 3 ", true, b"par)en)ted", b" R "),
        // spoof attempts
        (
            b"(lol) R 12) R 2122981 2123483 ",
            true,
            b"lol) R 12",
            b" R ",
        ),
        (b"(foo) S 999) S 1 2 3 ", true, b"foo) S 999", b" S "),
        (b"(x) 999) X 2 3 4 ", true, b"x) 999", b" X "),
        // names made entirely of ')'
        (
            b"()))))))))))))))) R 1 2 3 ",
            true,
            b")))))))))))))))",
            b" R ",
        ),
        // TASK_COMM_LEN boundary (15 bytes name allowed)
        (
            b"(1234567890abcde) R 1 2 3 ",
            true,
            b"1234567890abcde",
            b" R ",
        ), // 15 bytes
        (
            b"(aaaaaaaaaaaaaaa) R 1 2 3 ",
            true,
            b"aaaaaaaaaaaaaaa",
            b" R ",
        ), // 15 bytes
        // 16 bytes before ')' -> should be rejected (no ')' within first 16 bytes after '(')
        (b"(aaaaaaaaaaaaaaaa) R 1 2 3 ", false, b"", b""),
        // UTF-8 inside first 15 bytes
        (
            b"(\xF0\x9F\x98\x80a\xF0\x9F\x98\x80b) R 1 2 3 ",
            true,
            b"\xF0\x9F\x98\x80a\xF0\x9F\x98\x80b",
            b" R ",
        ),
        (
            b"(\xE2\x98\x83\xE2\x98\x83\xE2\x98\x83) R 1 2 3 ",
            true,
            b"\xE2\x98\x83\xE2\x98\x83\xE2\x98\x83",
            b" R ",
        ),
        // minimal trailer
        (b"(ok) R ", true, b"ok", b" R"),
        // many parens/spaces
        (b"(()()) ) R 1 2 3 ", true, b"()()) ", b" R "),
        (b"(()()())) ) R 1 2 3 ", true, b"()()())) ", b" R "),
        (b"(a) ) ) ) ) R 1 2 3 ", true, b"a) ) ) ) ", b" R "),
        // digits/spoof inside comm
        (b"(123) 456) R 1 2 3 ", true, b"123) 456", b" R "),
        (
            b"(statelike) R12) R 1 2 3 ",
            true,
            b"statelike) R12",
            b" R ",
        ),
        // edge embedded cases
        (b"())()()) ) R 1 2 3 ", true, b"))()()) ", b" R "),
        (b"()()()()()()() R 1 2 3 ", true, b")()()()()()(", b" R "),
        // truncated / malformed (should error)
        (b"(no close R 1 2 3 ", false, b"", b""),
        (b"no-open-paren) R 1 2 3 ", false, b"", b""),
        (b"(", false, b"", b""),
        (b"(aaaaaaaaaaaaaa", false, b"", b""), // truncated without ')'
        (b"(a", false, b"", b""),
        // NUL inside the comm (counts towards bytes)
        (b"(nul\0in) R 1 2 3 ", true, b"nul\0in", b" R "),
        // additional adversarial mixes (still within 15 bytes visible)
        (
            b"(()()(()))(())) ) R 1 2 3 ",
            true,
            b"()()(()))(())) ",
            b" R ",
        ),
        (
            b"(()))(()))(())) ) R 1 2 3 ",
            true,
            b"()))(()))(())) ",
            b" R ",
        ),
        (b"(a)b)c)d)e) f) R 1 2 3 ", true, b"a)b)c)d)e) f", b" R "),
        (
            b"()))))))))))))) ) R 1 2 3 ",
            true,
            b")))))))))))))) ",
            b" R ",
        ),
        (
            b"(()(()(()))))) ) R 1 2 3 ",
            true,
            b"()(()(()))))) ",
            b" R ",
        ),
        (b"(prefix) S  ", true, b"prefix", b" S  "),
        (b"(tricky)R 1 2 3 ", true, b"tricky", b"R 1 "),
        (
            b"(123456789012345) R 1 2 3 ",
            true,
            b"123456789012345",
            b" R ",
        ),
        (b"(1234567890123456) R 1 2 3 ", false, b"", b""), // 16 bytes before ')'
        (b"( trailing ) T 1 2 3 ", true, b" trailing ", b" T "),
    ];

    /// Runs `parse_comm` over every entry of `CASES`.
    ///
    /// Case numbers in failure messages are 1-based in every branch.
    #[test]
    fn proc_test_comm() {
        for (idx, case) in CASES.iter().enumerate() {
            let (input, should_parse, want, want_rest_prefix) = *case;
            let got = parse_comm(input);
            if should_parse {
                match got {
                    Ok((rest, comm)) => {
                        assert_eq!(
                            comm,
                            XPath::from_bytes(want),
                            "case {}: comm mismatch; input=`{}'; want=`{}'; got=`{comm}'",
                            idx + 1,
                            XPath::from_bytes(input),
                            XPath::from_bytes(want),
                        );
                        assert!(
                            rest.starts_with(want_rest_prefix),
                            "case {}: rest prefix mismatch; rest={:?}; want_prefix={:?}; input={:?}",
                            idx + 1,
                            rest,
                            want_rest_prefix,
                            input
                        );
                    }
                    Err(e) => {
                        // Use the same 1-based numbering as the other messages.
                        panic!(
                            "case {}: expected Ok but got Err({:?}); input={:?}",
                            idx + 1,
                            e,
                            input
                        );
                    }
                }
            } else {
                assert!(
                    got.is_err(),
                    "case {}: expected Err but got Ok; input=`{}'; parsed={:?}",
                    idx + 1,
                    XPath::from_bytes(input),
                    got.map(|(rest, got)| (XPath::from_bytes(rest), got))
                );
            }
        }
    }
}
