//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CLANG_CIR_DIALECT_BUILDER_CIRBASEBUILDER_H
#define LLVM_CLANG_CIR_DIALECT_BUILDER_CIRBASEBUILDER_H

#include "clang/AST/CharUnits.h"
#include "clang/Basic/AddressSpaces.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/Support/ErrorHandling.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"

namespace cir {

/// Overflow-handling flags for the integer arithmetic helpers of
/// CIRBaseBuilderTy (createAdd / createSub / createMul).
///
/// Each enumerator occupies its own bit so that flags can be combined and
/// tested with the bitwise operators defined alongside this enum:
///   - NoSignedWrap   -> sets the nsw flag on the created op.
///   - NoUnsignedWrap -> sets the nuw flag on the created op.
///   - Saturated      -> sets the saturated flag (createAdd / createSub only).
enum class OverflowBehavior {
  None = 0,
  NoSignedWrap = 0x1,
  NoUnsignedWrap = 0x2,
  Saturated = 0x4,
};

/// Returns the union of the flag sets \p a and \p b.
constexpr OverflowBehavior operator|(OverflowBehavior a, OverflowBehavior b) {
  const auto combined = llvm::to_underlying(a) | llvm::to_underlying(b);
  return static_cast<OverflowBehavior>(combined);
}

/// Returns the intersection of the flag sets \p a and \p b; a non-`None`
/// result means at least one flag is shared.
constexpr OverflowBehavior operator&(OverflowBehavior a, OverflowBehavior b) {
  const auto common = llvm::to_underlying(a) & llvm::to_underlying(b);
  return static_cast<OverflowBehavior>(common);
}

/// Adds the flags of \p b to \p a in place and returns \p a.
constexpr OverflowBehavior &operator|=(OverflowBehavior &a,
                                       OverflowBehavior b) {
  return a = a | b;
}

/// Keeps in \p a only the flags that are also present in \p b; returns \p a.
constexpr OverflowBehavior &operator&=(OverflowBehavior &a,
                                       OverflowBehavior b) {
  return a = a & b;
}

class CIRBaseBuilderTy : public mlir::OpBuilder {

public:
  CIRBaseBuilderTy(mlir::MLIRContext &mlirContext)
      : mlir::OpBuilder(&mlirContext) {}
  CIRBaseBuilderTy(mlir::OpBuilder &builder) : mlir::OpBuilder(builder) {}

  mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
                            const llvm::APInt &val) {
    return cir::ConstantOp::create(*this, loc, cir::IntAttr::get(typ, val));
  }

  cir::ConstantOp getConstant(mlir::Location loc, mlir::TypedAttr attr) {
    return cir::ConstantOp::create(*this, loc, attr);
  }

  cir::ConstantOp getConstantInt(mlir::Location loc, mlir::Type ty,
                                 int64_t value) {
    return getConstant(loc, cir::IntAttr::get(ty, value));
  }

  mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits) {
    auto type = cir::IntType::get(getContext(), numBits, /*isSigned=*/true);
    return getConstAPInt(loc, type,
                         llvm::APInt(numBits, val, /*isSigned=*/true));
  }

  mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val,
                             unsigned numBits) {
    auto type = cir::IntType::get(getContext(), numBits, /*isSigned=*/false);
    return getConstAPInt(loc, type, llvm::APInt(numBits, val));
  }

  // Creates constant null value for integral type ty.
  cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc) {
    return getConstant(loc, getZeroInitAttr(ty));
  }

  mlir::TypedAttr getConstNullPtrAttr(mlir::Type t) {
    assert(mlir::isa<cir::PointerType>(t) && "expected cir.ptr");
    return getConstPtrAttr(t, 0);
  }

  mlir::TypedAttr getNullDataMemberAttr(cir::DataMemberType ty) {
    return cir::DataMemberAttr::get(ty);
  }

  mlir::TypedAttr getZeroInitAttr(mlir::Type ty) {
    if (mlir::isa<cir::IntType>(ty))
      return cir::IntAttr::get(ty, 0);
    if (cir::isAnyFloatingPointType(ty))
      return cir::FPAttr::getZero(ty);
    if (auto complexType = mlir::dyn_cast<cir::ComplexType>(ty))
      return cir::ZeroAttr::get(complexType);
    if (auto arrTy = mlir::dyn_cast<cir::ArrayType>(ty))
      return cir::ZeroAttr::get(arrTy);
    if (auto vecTy = mlir::dyn_cast<cir::VectorType>(ty))
      return cir::ZeroAttr::get(vecTy);
    if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(ty))
      return getConstNullPtrAttr(ptrTy);
    if (auto recordTy = mlir::dyn_cast<cir::RecordType>(ty))
      return cir::ZeroAttr::get(recordTy);
    if (auto dataMemberTy = mlir::dyn_cast<cir::DataMemberType>(ty))
      return getNullDataMemberAttr(dataMemberTy);
    if (mlir::isa<cir::BoolType>(ty)) {
      return getFalseAttr();
    }
    llvm_unreachable("Zero initializer for given type is NYI");
  }

  cir::ConstantOp getBool(bool state, mlir::Location loc) {
    return cir::ConstantOp::create(*this, loc, getCIRBoolAttr(state));
  }
  cir::ConstantOp getFalse(mlir::Location loc) { return getBool(false, loc); }
  cir::ConstantOp getTrue(mlir::Location loc) { return getBool(true, loc); }

  cir::BoolType getBoolTy() { return cir::BoolType::get(getContext()); }
  cir::VoidType getVoidTy() { return cir::VoidType::get(getContext()); }

  cir::IntType getUIntNTy(int n) {
    return cir::IntType::get(getContext(), n, false);
  }

  static unsigned getCIRIntOrFloatBitWidth(mlir::Type eltTy) {
    if (auto intType = mlir::dyn_cast<cir::IntTypeInterface>(eltTy))
      return intType.getWidth();
    if (auto floatType = mlir::dyn_cast<cir::FPTypeInterface>(eltTy))
      return floatType.getWidth();

    llvm_unreachable("Unsupported type in getCIRIntOrFloatBitWidth");
  }
  cir::IntType getSIntNTy(int n) {
    return cir::IntType::get(getContext(), n, true);
  }

  cir::PointerType getPointerTo(mlir::Type ty) {
    return cir::PointerType::get(ty);
  }

  cir::PointerType getPointerTo(mlir::Type ty, cir::TargetAddressSpaceAttr as) {
    return cir::PointerType::get(ty, as);
  }

  cir::PointerType getPointerTo(mlir::Type ty, clang::LangAS langAS) {
    if (langAS == clang::LangAS::Default) // Default address space.
      return getPointerTo(ty);

    if (clang::isTargetAddressSpace(langAS)) {
      unsigned addrSpace = clang::toTargetAddressSpace(langAS);
      auto asAttr = cir::TargetAddressSpaceAttr::get(
          getContext(), getUI32IntegerAttr(addrSpace));
      return getPointerTo(ty, asAttr);
    }

    llvm_unreachable("language-specific address spaces NYI");
  }

  cir::PointerType getVoidPtrTy(clang::LangAS langAS = clang::LangAS::Default) {
    return getPointerTo(cir::VoidType::get(getContext()), langAS);
  }

  cir::PointerType getVoidPtrTy(cir::TargetAddressSpaceAttr as) {
    return getPointerTo(cir::VoidType::get(getContext()), as);
  }

  cir::MethodAttr getMethodAttr(cir::MethodType ty, cir::FuncOp methodFuncOp) {
    auto methodFuncSymbolRef = mlir::FlatSymbolRefAttr::get(methodFuncOp);
    return cir::MethodAttr::get(ty, methodFuncSymbolRef);
  }

  cir::MethodAttr getNullMethodAttr(cir::MethodType ty) {
    return cir::MethodAttr::get(ty);
  }

  cir::BoolAttr getCIRBoolAttr(bool state) {
    return cir::BoolAttr::get(getContext(), state);
  }

  cir::BoolAttr getTrueAttr() { return getCIRBoolAttr(true); }
  cir::BoolAttr getFalseAttr() { return getCIRBoolAttr(false); }

  mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real,
                                  mlir::Value imag) {
    auto resultComplexTy = cir::ComplexType::get(real.getType());
    return cir::ComplexCreateOp::create(*this, loc, resultComplexTy, real,
                                        imag);
  }

  mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand) {
    auto resultType = operand.getType();
    if (auto complexResultType = mlir::dyn_cast<cir::ComplexType>(resultType))
      resultType = complexResultType.getElementType();
    return cir::ComplexRealOp::create(*this, loc, resultType, operand);
  }

  mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand) {
    auto resultType = operand.getType();
    if (auto complexResultType = mlir::dyn_cast<cir::ComplexType>(resultType))
      resultType = complexResultType.getElementType();
    return cir::ComplexImagOp::create(*this, loc, resultType, operand);
  }

  cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr,
                         bool isVolatile = false, uint64_t alignment = 0) {
    mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
    return cir::LoadOp::create(*this, loc, ptr, /*isDeref=*/false, isVolatile,
                               alignmentAttr, cir::SyncScopeKindAttr{},
                               cir::MemOrderAttr{});
  }

  mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr,
                                uint64_t alignment) {
    return createLoad(loc, ptr, /*isVolatile=*/false, alignment);
  }

  mlir::Value createNot(mlir::Value value) {
    return cir::UnaryOp::create(*this, value.getLoc(), value.getType(),
                                cir::UnaryOpKind::Not, value);
  }

  /// Create a do-while operation.
  cir::DoWhileOp createDoWhile(
      mlir::Location loc,
      llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
      llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
    return cir::DoWhileOp::create(*this, loc, condBuilder, bodyBuilder);
  }

  /// Create a while operation.
  cir::WhileOp createWhile(
      mlir::Location loc,
      llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
      llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
    return cir::WhileOp::create(*this, loc, condBuilder, bodyBuilder);
  }

  /// Create a for operation.
  cir::ForOp createFor(
      mlir::Location loc,
      llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
      llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder,
      llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> stepBuilder) {
    return cir::ForOp::create(*this, loc, condBuilder, bodyBuilder,
                              stepBuilder);
  }

  /// Create a break operation.
  cir::BreakOp createBreak(mlir::Location loc) {
    return cir::BreakOp::create(*this, loc);
  }

  /// Create a continue operation.
  cir::ContinueOp createContinue(mlir::Location loc) {
    return cir::ContinueOp::create(*this, loc);
  }

  mlir::Value createUnaryOp(mlir::Location loc, cir::UnaryOpKind kind,
                            mlir::Value operand) {
    return cir::UnaryOp::create(*this, loc, kind, operand);
  }

  mlir::TypedAttr getConstPtrAttr(mlir::Type type, int64_t value) {
    return cir::ConstPtrAttr::get(type, getI64IntegerAttr(value));
  }

  mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
                           mlir::Type type, llvm::StringRef name,
                           mlir::IntegerAttr alignment,
                           mlir::Value dynAllocSize) {
    return cir::AllocaOp::create(*this, loc, addrType, type, name, alignment,
                                 dynAllocSize);
  }

  mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
                           mlir::Type type, llvm::StringRef name,
                           clang::CharUnits alignment,
                           mlir::Value dynAllocSize) {
    mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
    return createAlloca(loc, addrType, type, name, alignmentAttr, dynAllocSize);
  }

  mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
                           mlir::Type type, llvm::StringRef name,
                           mlir::IntegerAttr alignment) {
    return cir::AllocaOp::create(*this, loc, addrType, type, name, alignment);
  }

  mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType,
                           mlir::Type type, llvm::StringRef name,
                           clang::CharUnits alignment) {
    mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
    return createAlloca(loc, addrType, type, name, alignmentAttr);
  }

  /// Get constant address of a global variable as an MLIR attribute.
  /// This wrapper infers the attribute type through the global op.
  cir::GlobalViewAttr getGlobalViewAttr(cir::GlobalOp globalOp,
                                        mlir::ArrayAttr indices = {}) {
    cir::PointerType type = getPointerTo(globalOp.getSymType());
    return getGlobalViewAttr(type, globalOp, indices);
  }

  /// Get constant address of a global variable as an MLIR attribute.
  cir::GlobalViewAttr getGlobalViewAttr(cir::PointerType type,
                                        cir::GlobalOp globalOp,
                                        mlir::ArrayAttr indices = {}) {
    auto symbol = mlir::FlatSymbolRefAttr::get(globalOp.getSymNameAttr());
    return cir::GlobalViewAttr::get(type, symbol, indices);
  }

  mlir::Value createGetGlobal(mlir::Location loc, cir::GlobalOp global,
                              bool threadLocal = false) {
    assert(!cir::MissingFeatures::addressSpace());
    return cir::GetGlobalOp::create(*this, loc,
                                    getPointerTo(global.getSymType()),
                                    global.getSymNameAttr(), threadLocal);
  }

  mlir::Value createGetGlobal(cir::GlobalOp global, bool threadLocal = false) {
    return createGetGlobal(global.getLoc(), global, threadLocal);
  }

  /// Create a copy with inferred length.
  cir::CopyOp createCopy(mlir::Value dst, mlir::Value src,
                         bool isVolatile = false) {
    return cir::CopyOp::create(*this, dst.getLoc(), dst, src, isVolatile);
  }

  cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst,
                           bool isVolatile = false,
                           mlir::IntegerAttr align = {},
                           cir::SyncScopeKindAttr scope = {},
                           cir::MemOrderAttr order = {}) {
    return cir::StoreOp::create(*this, loc, val, dst, isVolatile, align, scope,
                                order);
  }

  /// Emit a load from an boolean flag variable.
  cir::LoadOp createFlagLoad(mlir::Location loc, mlir::Value addr) {
    mlir::Type boolTy = getBoolTy();
    if (boolTy != mlir::cast<cir::PointerType>(addr.getType()).getPointee())
      addr = createPtrBitcast(addr, boolTy);
    return createLoad(loc, addr, /*isVolatile=*/false, /*alignment=*/1);
  }

  cir::StoreOp createFlagStore(mlir::Location loc, bool val, mlir::Value dst) {
    mlir::Value flag = getBool(val, loc);
    return CIRBaseBuilderTy::createStore(loc, flag, dst);
  }

  [[nodiscard]] cir::GlobalOp createGlobal(mlir::ModuleOp mlirModule,
                                           mlir::Location loc,
                                           mlir::StringRef name,
                                           mlir::Type type, bool isConstant,
                                           cir::GlobalLinkageKind linkage) {
    mlir::OpBuilder::InsertionGuard guard(*this);
    setInsertionPointToStart(mlirModule.getBody());
    return cir::GlobalOp::create(*this, loc, name, type, isConstant, linkage);
  }

  cir::GetMemberOp createGetMember(mlir::Location loc, mlir::Type resultTy,
                                   mlir::Value base, llvm::StringRef name,
                                   unsigned index) {
    return cir::GetMemberOp::create(*this, loc, resultTy, base, name, index);
  }

  mlir::Value createDummyValue(mlir::Location loc, mlir::Type type,
                               clang::CharUnits alignment) {
    mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
    auto addr = createAlloca(loc, getPointerTo(type), type, {}, alignmentAttr);
    return cir::LoadOp::create(*this, loc, addr, /*isDeref=*/false,
                               /*isVolatile=*/false, alignmentAttr,
                               /*sync_scope=*/{}, /*mem_order=*/{});
  }

  cir::PtrStrideOp createPtrStride(mlir::Location loc, mlir::Value base,
                                   mlir::Value stride) {
    return cir::PtrStrideOp::create(*this, loc, base.getType(), base, stride);
  }

  //===--------------------------------------------------------------------===//
  // Call operators
  //===--------------------------------------------------------------------===//

  cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee,
                           mlir::Type returnType, mlir::ValueRange operands,
                           llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
                           llvm::ArrayRef<mlir::NamedAttribute> resAttrs = {}) {
    auto op = cir::CallOp::create(*this, loc, callee, returnType, operands);
    op->setAttrs(attrs);

    assert(!cir::MissingFeatures::functionArgumentAttrs());
    // TODO(cir): At one point we'll have to do a similar thing to this for the
    // argument attributes, except for those, there are 1 Dictionary per
    // argument. Since we only have 1 result however, we can just use a single
    // dictionary here, wrapped in an ArrayAttr of 1.
    auto resultDictAttr = mlir::DictionaryAttr::get(getContext(), resAttrs);
    op.setResAttrsAttr(mlir::ArrayAttr::get(getContext(), resultDictAttr));
    return op;
  }

  cir::CallOp createCallOp(mlir::Location loc, cir::FuncOp callee,
                           mlir::ValueRange operands,
                           llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
                           llvm::ArrayRef<mlir::NamedAttribute> resAttrs = {}) {
    return createCallOp(loc, mlir::SymbolRefAttr::get(callee),
                        callee.getFunctionType().getReturnType(), operands,
                        attrs, resAttrs);
  }

  cir::CallOp
  createIndirectCallOp(mlir::Location loc, mlir::Value indirectTarget,
                       cir::FuncType funcType, mlir::ValueRange operands,
                       llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
                       llvm::ArrayRef<mlir::NamedAttribute> resAttrs = {}) {
    llvm::SmallVector<mlir::Value> resOperands{indirectTarget};
    resOperands.append(operands.begin(), operands.end());
    return createCallOp(loc, mlir::SymbolRefAttr(), funcType.getReturnType(),
                        resOperands, attrs, resAttrs);
  }

  cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee,
                           mlir::ValueRange operands = mlir::ValueRange(),
                           llvm::ArrayRef<mlir::NamedAttribute> attrs = {},
                           llvm::ArrayRef<mlir::NamedAttribute> resAttrs = {}) {
    return createCallOp(loc, callee, cir::VoidType(), operands, attrs,
                        resAttrs);
  }

  //===--------------------------------------------------------------------===//
  // Cast/Conversion Operators
  //===--------------------------------------------------------------------===//

  mlir::Value createCast(mlir::Location loc, cir::CastKind kind,
                         mlir::Value src, mlir::Type newTy) {
    if (newTy == src.getType())
      return src;
    return cir::CastOp::create(*this, loc, newTy, kind, src);
  }

  mlir::Value createCast(cir::CastKind kind, mlir::Value src,
                         mlir::Type newTy) {
    if (newTy == src.getType())
      return src;
    return createCast(src.getLoc(), kind, src, newTy);
  }

  mlir::Value createIntCast(mlir::Value src, mlir::Type newTy) {
    return createCast(cir::CastKind::integral, src, newTy);
  }

  mlir::Value createIntToPtr(mlir::Value src, mlir::Type newTy) {
    return createCast(cir::CastKind::int_to_ptr, src, newTy);
  }

  mlir::Value createPtrToInt(mlir::Value src, mlir::Type newTy) {
    return createCast(cir::CastKind::ptr_to_int, src, newTy);
  }

  mlir::Value createPtrToBoolCast(mlir::Value v) {
    return createCast(cir::CastKind::ptr_to_bool, v, getBoolTy());
  }

  mlir::Value createBoolToInt(mlir::Value src, mlir::Type newTy) {
    return createCast(cir::CastKind::bool_to_int, src, newTy);
  }

  mlir::Value createBitcast(mlir::Value src, mlir::Type newTy) {
    return createCast(cir::CastKind::bitcast, src, newTy);
  }

  mlir::Value createBitcast(mlir::Location loc, mlir::Value src,
                            mlir::Type newTy) {
    return createCast(loc, cir::CastKind::bitcast, src, newTy);
  }

  mlir::Value createPtrBitcast(mlir::Value src, mlir::Type newPointeeTy) {
    assert(mlir::isa<cir::PointerType>(src.getType()) && "expected ptr src");
    return createBitcast(src, getPointerTo(newPointeeTy));
  }

  mlir::Value createPtrIsNull(mlir::Value ptr) {
    mlir::Value nullPtr = getNullPtr(ptr.getType(), ptr.getLoc());
    return createCompare(ptr.getLoc(), cir::CmpOpKind::eq, ptr, nullPtr);
  }

  mlir::Value createAddrSpaceCast(mlir::Location loc, mlir::Value src,
                                  mlir::Type newTy) {
    return createCast(loc, cir::CastKind::address_space, src, newTy);
  }

  mlir::Value createAddrSpaceCast(mlir::Value src, mlir::Type newTy) {
    return createAddrSpaceCast(src.getLoc(), src, newTy);
  }

  //===--------------------------------------------------------------------===//
  // Other Instructions
  //===--------------------------------------------------------------------===//

  mlir::Value createExtractElement(mlir::Location loc, mlir::Value vec,
                                   uint64_t idx) {
    mlir::Value idxVal =
        getConstAPInt(loc, getUIntNTy(64), llvm::APInt(64, idx));
    return cir::VecExtractOp::create(*this, loc, vec, idxVal);
  }

  mlir::Value createInsertElement(mlir::Location loc, mlir::Value vec,
                                  mlir::Value newElt, uint64_t idx) {
    mlir::Value idxVal =
        getConstAPInt(loc, getUIntNTy(64), llvm::APInt(64, idx));
    return cir::VecInsertOp::create(*this, loc, vec, newElt, idxVal);
  }

  //===--------------------------------------------------------------------===//
  // Binary Operators
  //===--------------------------------------------------------------------===//

  mlir::Value createBinop(mlir::Location loc, mlir::Value lhs,
                          cir::BinOpKind kind, mlir::Value rhs) {
    return cir::BinOp::create(*this, loc, lhs.getType(), kind, lhs, rhs);
  }

  mlir::Value createLowBitsSet(mlir::Location loc, unsigned size,
                               unsigned bits) {
    llvm::APInt val = llvm::APInt::getLowBitsSet(size, bits);
    auto type = cir::IntType::get(getContext(), size, /*isSigned=*/false);
    return getConstAPInt(loc, type, val);
  }

  mlir::Value createAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs) {
    return createBinop(loc, lhs, cir::BinOpKind::And, rhs);
  }

  mlir::Value createOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs) {
    return createBinop(loc, lhs, cir::BinOpKind::Or, rhs);
  }

  mlir::Value createSelect(mlir::Location loc, mlir::Value condition,
                           mlir::Value trueValue, mlir::Value falseValue) {
    assert(trueValue.getType() == falseValue.getType() &&
           "trueValue and falseValue should have the same type");
    return cir::SelectOp::create(*this, loc, trueValue.getType(), condition,
                                 trueValue, falseValue);
  }

  mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs,
                               mlir::Value rhs) {
    return createSelect(loc, lhs, rhs, getBool(false, loc));
  }

  mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs,
                              mlir::Value rhs) {
    return createSelect(loc, lhs, getBool(true, loc), rhs);
  }

  mlir::Value createMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
                        OverflowBehavior ob = OverflowBehavior::None) {
    auto op = cir::BinOp::create(*this, loc, lhs.getType(), cir::BinOpKind::Mul,
                                 lhs, rhs);
    op.setNoUnsignedWrap(
        llvm::to_underlying(ob & OverflowBehavior::NoUnsignedWrap));
    op.setNoSignedWrap(
        llvm::to_underlying(ob & OverflowBehavior::NoSignedWrap));
    return op;
  }
  mlir::Value createNSWMul(mlir::Location loc, mlir::Value lhs,
                           mlir::Value rhs) {
    return createMul(loc, lhs, rhs, OverflowBehavior::NoSignedWrap);
  }
  mlir::Value createNUWAMul(mlir::Location loc, mlir::Value lhs,
                            mlir::Value rhs) {
    return createMul(loc, lhs, rhs, OverflowBehavior::NoUnsignedWrap);
  }

  mlir::Value createSub(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
                        OverflowBehavior ob = OverflowBehavior::None) {
    auto op = cir::BinOp::create(*this, loc, lhs.getType(), cir::BinOpKind::Sub,
                                 lhs, rhs);
    op.setNoUnsignedWrap(
        llvm::to_underlying(ob & OverflowBehavior::NoUnsignedWrap));
    op.setNoSignedWrap(
        llvm::to_underlying(ob & OverflowBehavior::NoSignedWrap));
    op.setSaturated(llvm::to_underlying(ob & OverflowBehavior::Saturated));
    return op;
  }

  mlir::Value createNSWSub(mlir::Location loc, mlir::Value lhs,
                           mlir::Value rhs) {
    return createSub(loc, lhs, rhs, OverflowBehavior::NoSignedWrap);
  }

  mlir::Value createNUWSub(mlir::Location loc, mlir::Value lhs,
                           mlir::Value rhs) {
    return createSub(loc, lhs, rhs, OverflowBehavior::NoUnsignedWrap);
  }

  mlir::Value createAdd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
                        OverflowBehavior ob = OverflowBehavior::None) {
    auto op = cir::BinOp::create(*this, loc, lhs.getType(), cir::BinOpKind::Add,
                                 lhs, rhs);
    op.setNoUnsignedWrap(
        llvm::to_underlying(ob & OverflowBehavior::NoUnsignedWrap));
    op.setNoSignedWrap(
        llvm::to_underlying(ob & OverflowBehavior::NoSignedWrap));
    op.setSaturated(llvm::to_underlying(ob & OverflowBehavior::Saturated));
    return op;
  }

  mlir::Value createNSWAdd(mlir::Location loc, mlir::Value lhs,
                           mlir::Value rhs) {
    return createAdd(loc, lhs, rhs, OverflowBehavior::NoSignedWrap);
  }

  mlir::Value createNUWAdd(mlir::Location loc, mlir::Value lhs,
                           mlir::Value rhs) {
    return createAdd(loc, lhs, rhs, OverflowBehavior::NoUnsignedWrap);
  }

  cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind,
                           mlir::Value lhs, mlir::Value rhs) {
    return cir::CmpOp::create(*this, loc, kind, lhs, rhs);
  }

  cir::VecCmpOp createVecCompare(mlir::Location loc, cir::CmpOpKind kind,
                                 mlir::Value lhs, mlir::Value rhs) {
    VectorType vecCast = mlir::cast<VectorType>(lhs.getType());
    IntType integralTy =
        getSIntNTy(getCIRIntOrFloatBitWidth(vecCast.getElementType()));
    VectorType integralVecTy =
        cir::VectorType::get(integralTy, vecCast.getSize());
    return cir::VecCmpOp::create(*this, loc, integralVecTy, kind, lhs, rhs);
  }

  mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand) {
    return createCompare(loc, cir::CmpOpKind::ne, operand, operand);
  }

  mlir::Value createShift(mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
                          bool isShiftLeft) {
    return cir::ShiftOp::create(*this, loc, lhs.getType(), lhs, rhs,
                                isShiftLeft);
  }

  mlir::Value createShift(mlir::Location loc, mlir::Value lhs,
                          const llvm::APInt &rhs, bool isShiftLeft) {
    return createShift(loc, lhs, getConstAPInt(loc, lhs.getType(), rhs),
                       isShiftLeft);
  }

  mlir::Value createShift(mlir::Location loc, mlir::Value lhs, unsigned bits,
                          bool isShiftLeft) {
    auto width = mlir::dyn_cast<cir::IntType>(lhs.getType()).getWidth();
    auto shift = llvm::APInt(width, bits);
    return createShift(loc, lhs, shift, isShiftLeft);
  }

  mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
                              unsigned bits) {
    return createShift(loc, lhs, bits, true);
  }

  mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
                               unsigned bits) {
    return createShift(loc, lhs, bits, false);
  }

  mlir::Value createShiftLeft(mlir::Location loc, mlir::Value lhs,
                              mlir::Value rhs) {
    return createShift(loc, lhs, rhs, true);
  }

  mlir::Value createShiftRight(mlir::Location loc, mlir::Value lhs,
                               mlir::Value rhs) {
    return createShift(loc, lhs, rhs, false);
  }

  //
  // Block handling helpers
  // ----------------------
  //
  static OpBuilder::InsertPoint getBestAllocaInsertPoint(mlir::Block *block) {
    auto last =
        std::find_if(block->rbegin(), block->rend(), [](mlir::Operation &op) {
          return mlir::isa<cir::AllocaOp, cir::LabelOp>(&op);
        });

    if (last != block->rend())
      return OpBuilder::InsertPoint(block, ++mlir::Block::iterator(&*last));
    return OpBuilder::InsertPoint(block, block->begin());
  };

  //
  // Alignment and size helpers
  //

  // Note that mlir::IntegerType is used instead of cir::IntType here because we
  // don't need sign information for these to be useful, so keep it simple.

  // For 0 alignment, any overload of `getAlignmentAttr` returns an empty
  // attribute.
  mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment) {
    return getAlignmentAttr(alignment.getQuantity());
  }

  mlir::IntegerAttr getAlignmentAttr(llvm::Align alignment) {
    return getAlignmentAttr(alignment.value());
  }

  mlir::IntegerAttr getAlignmentAttr(int64_t alignment) {
    return alignment ? getI64IntegerAttr(alignment) : mlir::IntegerAttr();
  }

  mlir::IntegerAttr getSizeFromCharUnits(clang::CharUnits size) {
    return getI64IntegerAttr(size.getQuantity());
  }

  // Creates constant nullptr for pointer type ty.
  cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc) {
    assert(!cir::MissingFeatures::targetCodeGenInfoGetNullPointer());
    return cir::ConstantOp::create(*this, loc, getConstPtrAttr(ty, 0));
  }

  /// Create a loop condition.
  cir::ConditionOp createCondition(mlir::Value condition) {
    return cir::ConditionOp::create(*this, condition.getLoc(), condition);
  }

  /// Create a yield operation.
  cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value = {}) {
    return cir::YieldOp::create(*this, loc, value);
  }

  struct GetMethodResults {
    mlir::Value callee;
    mlir::Value adjustedThis;
  };

  GetMethodResults createGetMethod(mlir::Location loc, mlir::Value method,
                                   mlir::Value objectPtr) {
    // Build the callee function type.
    auto methodFuncTy =
        mlir::cast<cir::MethodType>(method.getType()).getMemberFuncTy();
    auto methodFuncInputTypes = methodFuncTy.getInputs();

    auto objectPtrTy = mlir::cast<cir::PointerType>(objectPtr.getType());
    mlir::Type adjustedThisTy = getVoidPtrTy(objectPtrTy.getAddrSpace());

    llvm::SmallVector<mlir::Type> calleeFuncInputTypes{adjustedThisTy};
    calleeFuncInputTypes.insert(calleeFuncInputTypes.end(),
                                methodFuncInputTypes.begin(),
                                methodFuncInputTypes.end());
    cir::FuncType calleeFuncTy =
        methodFuncTy.clone(calleeFuncInputTypes, methodFuncTy.getReturnType());
    // TODO(cir): consider the address space of the callee.
    assert(!cir::MissingFeatures::addressSpace());
    cir::PointerType calleeTy = getPointerTo(calleeFuncTy);

    auto op = cir::GetMethodOp::create(*this, loc, calleeTy, adjustedThisTy,
                                       method, objectPtr);
    return {op.getCallee(), op.getAdjustedThis()};
  }
};

} // namespace cir

#endif
