#pragma once #include "common.h" #include "core/ref.h" #include // hash #include #include namespace nnet { template using Ref = infini::Ref; template Ref make_ref(Params &&...params) { return infini::make_ref(std::forward(params)...); } template > * = nullptr> Ref as(const Ref &ref) { return infini::as(ref); } // Comparator for Ref template struct is_ref : std::false_type {}; template struct is_ref> : std::true_type {}; template typename std::enable_if_t>::value, bool> __ref_less(const Tuple &lhs, const Tuple &rhs) { if constexpr (index >= std::tuple_size>::value - 1) return std::get(lhs) < std::get(rhs); else { if (std::get(lhs) != std::get(rhs)) return std::get(lhs) < std::get(rhs); else return __ref_less(lhs, rhs); } } template typename std::enable_if_t>::value and not address_based, bool> __ref_less(const Tuple &lhs, const Tuple &rhs) { if constexpr (index >= std::tuple_size>::value - 1) return std::get(lhs)->less(std::get(rhs)); else { if (std::get(lhs)->neq(std::get(rhs))) return std::get(lhs)->less(std::get(rhs)); else return __ref_less(lhs, rhs); } } template typename std::enable_if_t< is_ref>::value and address_based, bool> __ref_less(const Tuple &lhs, const Tuple &rhs) { if constexpr (index >= std::tuple_size>::value - 1) return std::get(lhs).get() < std::get(rhs).get(); else { if (std::get(lhs).get() != std::get(rhs).get()) return std::get(lhs).get() < std::get(rhs).get(); else return __ref_less(lhs, rhs); } } template bool ref_addr_less(const Tuple &lhs, const Tuple &rhs) { return __ref_less(lhs, rhs); } template bool ref_value_less(const Tuple &lhs, const Tuple &rhs) { return __ref_less(lhs, rhs); } template class RefAddrLess { public: bool operator()(const Tuple &a, const Tuple &b) const { return ref_addr_less(a, b); } }; template class RefValueLess { public: bool operator()(const Tuple &a, const Tuple &b) const { return ref_value_less(a, b); } }; // make_ref_from_tuple template constexpr Ref<_Tp> make_ref_from_tuple_impl(_Tuple &&__t, std::index_sequence<_Idx...>) { return make_ref<_Tp>(std::get<_Idx>(std::forward<_Tuple>(__t))...); } template constexpr Ref<_Tp> make_ref_from_tuple(_Tuple &&__t) { return make_ref_from_tuple_impl<_Tp>( std::forward<_Tuple>(__t), std::make_index_sequence>>{}); } } // namespace nnet // namespace std { // template struct hash> { // hash hash_; // size_t operator()(const ir::Ref &ref) const { return hash_(ref.get()); // } // }; // } // namespace nnet