// Part of the Crubit project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

use core::marker::PhantomData;
use core::mem::{self, MaybeUninit};
use core::ptr;

/// A mutually understood ABI for sending bridge types between Rust and C++.
///
/// Bridging values between Rust and C++ is typically done by breaking down values into their
/// primitive, ABI-compatible parts like integers and pointers in the native language. Then, these
/// primitive parts are sent across the language boundary, where the target language can reconstruct
/// the semantically equivalent value. This typically happens by sending the parts as function
/// arguments on an extern C function, but this doesn't work for generic/templated types without
/// monomorphizing each instantiation since C doesn't have templates.
///
/// The solution to this is to transform the value into an ABI-compatible layout that both languages
/// understand, allowing for passing arbitrarily complex types through C as a char pointer to a
/// stack allocated buffer. The [`CrubitAbi`] trait is used to describe this mutually understood
/// ABI.
///
/// # Creating a bridge type
///
/// Let's walk through the example of how `Option<T>` is bridged. The first step is to define an ABI
/// for how it should be bridged, which is represented by a type that implements [`CrubitAbi`]:
///
/// ```rust
/// pub struct OptionAbi<A>(pub A);
///
/// unsafe impl<A: CrubitAbi> CrubitAbi for OptionAbi<A> {
///     type Value = Option<A::Value>;
///
///     // The remaining items are filled in below.
/// }
/// ```
///
/// This is saying "`OptionAbi<A>` is a description of how to bridge an `Option<A::Value>`." We'll
/// get back to the unsafe part later. But before we proceed, we need to decide: what will the
/// Option ABI be? Rust allows for niche optimizations, but to keep things general we'll choose to
/// bridge `Option<T>` as a bool, followed by the value if the bool is true. To express this, we
/// need to implement the other items in the trait:
///
/// ```rust
/// unsafe impl<A: CrubitAbi> CrubitAbi for OptionAbi<A> {
///     type Value = Option<A::Value>;
///
///     const SIZE: usize = mem::size_of::<bool>() + A::SIZE;
///
///     fn encode(self, value: Self::Value, encoder: &mut Encoder) {
///         if let Some(inner) = value {
///             transmute_abi().encode(true, encoder);
///             self.0.encode(inner, encoder);
///         } else {
///             transmute_abi().encode(false, encoder);
///         }
///     }
///
///     unsafe fn decode(self, decoder: &mut Decoder) -> Self::Value {
///         // SAFETY: the caller guarantees that the buffer contains a bool, and if the bool is true,
///         // that the buffer also contains the value.
///         unsafe {
///             if transmute_abi().decode(decoder) {
///                 Some(self.0.decode(decoder))
///             } else {
///                 None
///             }
///         }
///     }
/// }
/// ```
///
/// There are several things going on here. First, we need to define the `SIZE` constant. This
/// information is used to statically compute the size of the buffer required to encode/decode an
/// `Option<T>` with this ABI, allowing us to stack allocate the buffer. Note that `SIZE` is an
/// upper bound: a `None` only writes the bool and leaves the space reserved for the value unused.
/// Importantly, the current implementation packs all the data with unaligned writes/reads, so
/// alignment information is not needed. Second, we need to define the `encode` and `decode`
/// functions. These functions implement the agreed-upon ABI: bool, optionally followed by the
/// value if the bool is true.
///
/// # Safety
///
/// It's safety critical that the C++ implementation matches the Rust implementation exactly, since
/// the ABI is supposed to be mutually understood.
pub unsafe trait CrubitAbi {
    /// The type that this CrubitAbi is encoding and decoding.
    type Value;

    /// The maximum size in bytes of a `Value` when encoded with this ABI. This is used to
    /// statically compute the size of the buffer required to encode/decode a `Value` with this
    /// ABI.
    const SIZE: usize;

    /// Encodes a `Value`, advancing the encoder's position by at most `SIZE` bytes.
    ///
    /// Aside from implementations for primitives, most implementations of this function will be
    /// composed of calls to [`CrubitAbi::encode`] on other ABIs `A: CrubitAbi`, each one
    /// advancing the encoder's position by at most `A::SIZE` bytes. The implementation must
    /// ensure that these calls do not advance the encoder's position by more than `SIZE` bytes
    /// in total. This is because the `SIZE` constant is used to compute the buffer size
    /// statically, and if the encoder's position is advanced by more than `SIZE`, the encoder
    /// may panic in debug builds, or cause undefined behavior in release builds.
    ///
    /// # Notes
    ///
    /// The value must be semantically moved into the encoder. This means that if you're
    /// transferring ownership of anything, you must ensure that the original owner leaks the
    /// resource so it can later be reclaimed by decoding. Prefer functions that explicitly leak,
    /// like [`Box::leak`], or defer to [`core::mem::ManuallyDrop`] and [`core::mem::forget`] if
    /// leaking APIs are unavailable.
    ///
    /// # Examples
    ///
    /// ```rust
    /// unsafe impl<A1: CrubitAbi, A2: CrubitAbi> CrubitAbi for (A1, A2) {
    ///     fn encode(self, (a, b): Self::Value, encoder: &mut Encoder) {
    ///         self.0.encode(a, encoder);
    ///         self.1.encode(b, encoder);
    ///     }
    ///     // other items omitted...
    /// }
    /// ```
    fn encode(self, value: Self::Value, encoder: &mut Encoder);

    /// Decodes a [`Self::Value`], advancing the decoder's position by at most `SIZE` bytes.
    ///
    /// Aside from implementations for primitives, most implementations of this function will be
    /// composed of calls to [`CrubitAbi::decode`] on other ABIs `A: CrubitAbi`, each one
    /// advancing the decoder's position by at most `A::SIZE` bytes. The implementation must
    /// ensure that these calls do not advance the decoder's position by more than `SIZE` bytes
    /// in total. This is because the `SIZE` constant is used to compute the buffer size
    /// statically, and if the decoder's position is advanced by more than `SIZE`, the decoder
    /// may panic in debug builds, or cause undefined behavior in release builds.
    ///
    /// # Examples
    ///
    /// ```rust
    /// unsafe impl<A1: CrubitAbi, A2: CrubitAbi> CrubitAbi for (A1, A2) {
    ///     unsafe fn decode(self, decoder: &mut Decoder) -> Self::Value {
    ///         // SAFETY: The caller guarantees that the buffer contains an `A1::Value`, followed
    ///         // by an `A2::Value`, which is the understood ABI for a `(A1, A2)`.
    ///         unsafe {
    ///             let a = self.0.decode(decoder);
    ///             let b = self.1.decode(decoder);
    ///             (a, b)
    ///         }
    ///     }
    ///     // other items omitted...
    /// }
    /// ```
    ///
    /// # Safety
    ///
    /// The caller guarantees that the buffer's current position contains a `Value` that was
    /// encoded with this ABI, e.g. by the C++ side of the bridge.
    unsafe fn decode(self, decoder: &mut Decoder) -> Self::Value;
}

/// A wrapper around a buffer that tracks which parts of a buffer have already been written to.
///
/// The buffer is filled from its end towards its start: each write first lowers
/// `remaining_bytes` and then stores the value at that offset from `buf`.
pub struct Encoder {
    // The number of bytes remaining in the buffer. We write backwards (counting down) so that
    // subtracting too much leads to underflow, which is checked in debug builds.
    remaining_bytes: usize,
    // Base pointer of the buffer being written. Values are stored at `buf + remaining_bytes`
    // after `remaining_bytes` has been decremented by the size of the value.
    buf: *mut u8,
}

/// A wrapper around a buffer that tracks which parts of a buffer have already been read from.
///
/// The buffer is consumed from its end towards its start: each read first lowers
/// `remaining_bytes` and then loads the value found at that offset from `buf`.
pub struct Decoder {
    // The number of bytes remaining in the buffer. We read backwards (counting down) so that
    // subtracting too much leads to underflow, which is checked in debug builds.
    remaining_bytes: usize,
    // Base pointer of the buffer being read. Values are loaded from `buf + remaining_bytes`
    // after `remaining_bytes` has been decremented by the size of the value.
    buf: *const u8,
}

/// A [`CrubitAbi`] that encodes a `T` by copying its raw bytes into the buffer.
///
/// The struct carries no data of its own; the `PhantomData<T>` only records which type is being
/// bridged.
pub struct TransmuteAbi<T>(pub PhantomData<T>);

/// Shorthand for constructing a new [`TransmuteAbi<T>`].
pub fn transmute_abi<T>() -> TransmuteAbi<T> {
    TransmuteAbi(PhantomData)
}

// Written by hand instead of derived so that no `T: Default` bound is imposed.
impl<T> Default for TransmuteAbi<T> {
    fn default() -> Self {
        Self(PhantomData)
    }
}

// Written by hand instead of derived so that no `T: Clone` bound is imposed.
impl<T> Clone for TransmuteAbi<T> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}

// Every T can be passed by value.
// SAFETY: The ABI contract for `TransmuteAbi<T>` is that the raw bytes of the value `T` are
// memcpy'd into the buffer, padding and all. The idea is that this is only used on types that
// already have a shared ABI between Rust and C++, like primitives and opaque types.
unsafe impl<T> CrubitAbi for TransmuteAbi<T> {
    type Value = T;

    const SIZE: usize = mem::size_of::<T>();

    fn encode(self, value: Self::Value, encoder: &mut Encoder) {
        // Claim the next `SIZE` bytes, counting down from the end of the buffer. The subtraction
        // underflows if callers claim more than the buffer holds, which is checked in debug
        // builds and so catches writes past the buffer there.
        let offset = encoder.remaining_bytes - Self::SIZE;
        encoder.remaining_bytes = offset;

        // SAFETY: The `SIZE` bytes starting at `offset` were just reserved for this value, and
        // the write makes no alignment assumption about the destination.
        unsafe {
            let dst = encoder.buf.add(offset).cast::<T>();
            ptr::write_unaligned(dst, value);
        }
    }

    unsafe fn decode(self, decoder: &mut Decoder) -> Self::Value {
        // Claim the next `SIZE` bytes, counting down from the end of the buffer. The subtraction
        // underflows if callers claim more than the buffer holds, which is checked in debug
        // builds and so catches reads past the buffer there.
        let offset = decoder.remaining_bytes - Self::SIZE;
        decoder.remaining_bytes = offset;

        // SAFETY: The caller guarantees that the buffer contains a `T` at this position; the
        // read makes no alignment assumption about the source.
        unsafe {
            let src = decoder.buf.add(offset).cast::<T>();
            ptr::read_unaligned(src)
        }
    }
}

// Generates `unsafe impl CrubitAbi` blocks for tuples of ABIs.
//
// Each input line has the shape `unsafe impl CrubitAbi for (a1: A1, a2: A2, ...,);` where the
// lowercase identifier names the ABI component taken out of `self` and the uppercase identifier
// names the corresponding generic type parameter. In `encode`, the uppercase identifier is also
// reused as the name of the pattern binding for the matching element of the value tuple, which is
// why `non_snake_case` is allowed there.
macro_rules! unsafe_impl_crubit_abi_for_tuple {
    { $( unsafe impl CrubitAbi for ( $($a:ident : $A:ident,)*); )* } => {
        $(
            // SAFETY: The bridge schema for a tuple is the same in C++: each element of the tuple
            // is encoded in order with the corresponding schema.
            unsafe impl<$($A: CrubitAbi),*> CrubitAbi for ($($A,)*) {
                type Value = ( $($A::Value,)* );

                // Sum of the element sizes; `0` for the unit tuple.
                const SIZE: usize = 0 $( + $A::SIZE )*;

                 fn encode(self, ( $($A,)* ): Self::Value, encoder: &mut Encoder) {
                    #![allow(non_snake_case)]
                    #![allow(unused_variables)] // for `encoder` in () case
                    // `$a` is the ABI for element `$A` of the value tuple.
                    let ($($a,)*) = self;
                    // Encode the elements in declaration order.
                    $(
                        $a.encode($A, encoder);
                    )*
                }

                unsafe fn decode(self, decoder: &mut Decoder) -> Self::Value {
                    #![allow(clippy::unused_unit)] // for () case
                    #![allow(unused_variables)] // for `decoder` in () case

                    let ($($a,)*) = self;

                    // Decode the elements in declaration order, mirroring `encode`.
                    // SAFETY: the caller guarantees that the buffer contains each element of
                    // the tuple with the correct schema.
                    (
                        $(
                            unsafe { $a.decode(decoder) },
                        )*
                    )
                }
            }
        )*
    }
}

// Tuples of up to 12 elements (including the unit tuple) can be passed by bridge. Add more impls
// here if needed.
// SAFETY: The ABI contract for `(A1, A2, ..., An)` is that the elements of the tuple are encoded
// in order with the corresponding `CrubitAbi`s.
unsafe_impl_crubit_abi_for_tuple! {
    unsafe impl CrubitAbi for ();
    unsafe impl CrubitAbi for (a1: A1,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11,);
    unsafe impl CrubitAbi for (a1: A1, a2: A2, a3: A3, a4: A4, a5: A5, a6: A6, a7: A7, a8: A8, a9: A9, a10: A10, a11: A11, a12: A12,);
}

/// A [`CrubitAbi`] for encoding an `Option` by encoding a bool followed by the value if the bool
/// is true.
///
/// The wrapped `A` is the ABI used for the payload of a `Some`.
#[derive(Clone, Default)]
pub struct OptionAbi<A>(pub A);

// SAFETY: The ABI contract for `OptionAbi<A>` is that the value is encoded as a bool presence
// flag, followed by the payload (encoded with `A`) only when the flag is true.
unsafe impl<A: CrubitAbi> CrubitAbi for OptionAbi<A> {
    type Value = Option<A::Value>;

    // Room for the presence flag plus the largest possible payload.
    const SIZE: usize = mem::size_of::<bool>() + A::SIZE;

    fn encode(self, value: Self::Value, encoder: &mut Encoder) {
        match value {
            Some(inner) => {
                transmute_abi::<bool>().encode(true, encoder);
                self.0.encode(inner, encoder);
            }
            // Nothing follows the flag for `None`.
            None => transmute_abi::<bool>().encode(false, encoder),
        }
    }

    unsafe fn decode(self, decoder: &mut Decoder) -> Self::Value {
        // SAFETY: the caller guarantees that the buffer contains a bool at this position.
        let is_some = unsafe { transmute_abi::<bool>().decode(decoder) };
        if !is_some {
            return None;
        }

        // SAFETY: the caller guarantees that, since the bool is true, the buffer also contains
        // the value.
        Some(unsafe { self.0.decode(decoder) })
    }
}

/// Internal functions and types for Crubit generated code.
#[doc(hidden)]
pub mod internal {
    use super::*;

    /// Encodes `value` into the buffer at `buf` using `crubit_abi`.
    ///
    /// This function is only intended to be called by Crubit generated code.
    ///
    /// # Safety
    ///
    /// `buf` must point to a buffer that is large enough to hold the encoded value, i.e. one
    /// that is valid for writes of `A::SIZE` bytes.
    pub unsafe fn encode<A: CrubitAbi>(crubit_abi: A, buf: *mut u8, value: A::Value) {
        // Start with the whole buffer available; the encoder counts down from `A::SIZE`.
        let mut encoder = Encoder { remaining_bytes: A::SIZE, buf };
        crubit_abi.encode(value, &mut encoder);
    }

    /// Decodes a value from the buffer at `buf` using `crubit_abi`.
    ///
    /// This function is only intended to be called by Crubit generated code.
    ///
    /// # Safety
    ///
    /// `buf` must point to a buffer that is at least `A::SIZE` bytes large, and must contain an
    /// `A::Value` that was encoded with the same ABI `A`.
    pub unsafe fn decode<A: CrubitAbi>(crubit_abi: A, buf: *const u8) -> A::Value {
        // Start with the whole buffer unread; the decoder counts down from `A::SIZE`.
        let mut decoder = Decoder { remaining_bytes: A::SIZE, buf };

        // SAFETY: The caller guarantees that the buffer contains an `A::Value` that was encoded
        // with the ABI `A`.
        unsafe { crubit_abi.decode(&mut decoder) }
    }

    /// Helper function that returns an uninitialized buffer of `BYTES` bytes to reduce noise in
    /// the generated code.
    ///
    /// This function is intended to be used by Crubit generated code.
    pub const fn empty_buffer<const BYTES: usize>() -> [MaybeUninit<u8>; BYTES] {
        [MaybeUninit::<u8>::uninit(); BYTES]
    }
}

// Encodes the value `$expr` using the ABI value `$crubit_abi_expr` (whose type is `$crubit_abi`)
// and evaluates to the stack buffer, a `[MaybeUninit<u8>; SIZE]`, that the value was encoded into.
//
// This cannot be a function because it errors with "constant expression depends on a generic
// parameter" when constructing the buffer.
// This macro is unstable, and may be changed. Do not use this unless you have been approved by the
// Crubit team.
#[macro_export]
macro_rules! unstable_encode {
    {@ $crubit_abi_expr:expr, $crubit_abi:ty, $expr:expr} => {{
        // The buffer is sized from the ABI's `SIZE` constant and starts out uninitialized.
        let mut __crubit_tmp_buffer = [const { ::core::mem::MaybeUninit::<u8>::uninit() }; <$crubit_abi as $crate::CrubitAbi>::SIZE];
        let __crubit_tmp_value = $expr;
        // NOTE(review): presumably allowed so that expanding inside an existing `unsafe` block
        // does not warn -- confirm.
        #[allow(unused_unsafe)]
        // SAFETY: the buffer was created above with exactly `SIZE` bytes for this ABI.
        unsafe {
            $crate::internal::encode::<$crubit_abi>(
                $crubit_abi_expr,
                __crubit_tmp_buffer.as_mut_ptr() as *mut u8,
                __crubit_tmp_value,
            );
        }
        __crubit_tmp_buffer
    }};
}

// Allocates a stack buffer sized for `$crubit_abi`, passes a pointer to it to the callback `$cb`,
// and then decodes the buffer with the ABI value `$crubit_abi_expr`, evaluating to the decoded
// value.
//
// This cannot be a function because it errors with "constant expression depends on a generic
// parameter" when constructing the buffer.
// This macro is unstable, and may be changed. Do not use this unless you have been approved by the
// Crubit team.
#[macro_export]
macro_rules! unstable_return {
    {@ $crubit_abi_expr:expr, $crubit_abi:ty, $cb:expr} => {{
        // The buffer is sized from the ABI's `SIZE` constant and starts out uninitialized.
        let mut __crubit_tmp_buffer = [const { ::core::mem::MaybeUninit::<u8>::uninit() }; <$crubit_abi as $crate::CrubitAbi>::SIZE];
        // The callback is expected to fill the buffer with a value encoded with this ABI; the
        // decode below relies on that.
        ($cb)(__crubit_tmp_buffer.as_mut_ptr() as *mut u8);
        // NOTE(review): presumably allowed so that expanding inside an existing `unsafe` block
        // does not warn -- confirm.
        #[allow(unused_unsafe)]
        unsafe {
            $crate::internal::decode::<$crubit_abi>($crubit_abi_expr, __crubit_tmp_buffer.as_ptr() as *const u8)
        }
    }};
}

#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;

    // Round-trips a pair of bytes through a two-element tuple ABI.
    #[gtest]
    fn test_encode_decode_u8_pair() {
        type Abi = (TransmuteAbi<u8>, TransmuteAbi<u8>);

        let original = (1, 2);
        let encoded = unstable_encode!(@ (transmute_abi(), transmute_abi()), Abi, original);

        // SAFETY: `encoded` was just filled by encoding `original` with `Abi`.
        let value =
            unsafe { internal::decode::<Abi>(Abi::default(), encoded.as_ptr() as *const u8) };
        expect_eq!(value, original);
    }

    // Round-trips a nested value mixing an `OptionAbi` with tuple and transmute ABIs.
    #[gtest]
    fn test_encode_decode_stuff() {
        type Abi = (
            OptionAbi<(TransmuteAbi<i64>, TransmuteAbi<bool>)>,
            (TransmuteAbi<u8>, TransmuteAbi<f32>),
        );

        let original = (Some((-8, true)), (1, 2.0));
        let encoded = unstable_encode!(@
            (
                OptionAbi((transmute_abi(), transmute_abi())),
                (transmute_abi(), transmute_abi())
            ),
            Abi,
            original
        );

        // SAFETY: `encoded` was just filled by encoding `original` with `Abi`.
        let value =
            unsafe { internal::decode::<Abi>(Abi::default(), encoded.as_ptr() as *const u8) };
        expect_eq!(value, original);
    }
}
