// Copyright (C) 2013-2019 Johan Hake, Jan Blechta and Garth N. Wells
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
// SPDX-License-Identifier:    LGPL-3.0-or-later

#include "utils.h"
#include "Constant.h"
#include "DofMap.h"
#include "FiniteElement.h"
#include "Form.h"
#include "Function.h"
#include "FunctionSpace.h"
#include "dofmapbuilder.h"
#include <algorithm>
#include <array>
#include <dolfinx/common/IndexMap.h>
#include <dolfinx/common/Timer.h>
#include <dolfinx/common/log.h>
#include <dolfinx/la/SparsityPattern.h>
#include <dolfinx/mesh/Mesh.h>
#include <dolfinx/mesh/Topology.h>
#include <dolfinx/mesh/topologycomputation.h>
#include <memory>
#include <stdexcept>
#include <string>
#include <ufcx.h>

using namespace dolfinx;

//-----------------------------------------------------------------------------
fem::DofMap fem::create_dofmap(
    MPI_Comm comm, const ElementDofLayout& layout, mesh::Topology& topology,
    const std::function<void(std::span<std::int32_t>, std::uint32_t)>&
        permute_inv,
    const std::function<std::vector<int>(
        const graph::AdjacencyList<std::int32_t>&)>& reorder_fn)
{
  // Create required mesh entities, i.e. every sub-cell dimension on
  // which the layout places DOFs
  const int D = topology.dim();
  for (int d = 0; d < D; ++d)
  {
    if (layout.num_entity_dofs(d) > 0)
      topology.create_entities(d);
  }

  auto [_index_map, bs, dofmaps]
      = build_dofmap_data(comm, topology, {layout}, reorder_fn);
  auto index_map = std::make_shared<common::IndexMap>(std::move(_index_map));

  // If the element's DOF transformations are permutations, permute the
  // DOF numbering on each cell
  if (permute_inv)
  {
    auto c_to_v = topology.connectivity(D, 0);
    assert(c_to_v);
    const std::int32_t num_cells = c_to_v->num_nodes();
    topology.create_entity_permutations();
    const std::vector<std::uint32_t>& cell_info
        = topology.get_cell_permutation_info();

    // The dofmap is stored flat with `dim` entries per cell. Compute the
    // offset in std::size_t: `cell * dim` in 32-bit signed arithmetic
    // overflows (undefined behaviour) once the dofmap on this rank
    // exceeds ~2^31 entries.
    const std::size_t dim = layout.num_dofs();
    auto& dofmap = dofmaps.front();
    for (std::int32_t cell = 0; cell < num_cells; ++cell)
    {
      std::span<std::int32_t> dofs(dofmap.data() + cell * dim, dim);
      permute_inv(dofs, cell_info[cell]);
    }
  }

  return DofMap(layout, index_map, bs, std::move(dofmaps.front()), bs);
}
//-----------------------------------------------------------------------------
std::vector<fem::DofMap> fem::create_dofmaps(
    MPI_Comm comm, const std::vector<ElementDofLayout>& layouts,
    mesh::Topology& topology,
    const std::function<void(std::span<std::int32_t>, std::uint32_t)>&
        permute_inv,
    const std::function<std::vector<int>(
        const graph::AdjacencyList<std::int32_t>&)>& reorder_fn)
{
  std::int32_t D = topology.dim();
  assert(layouts.size() == topology.entity_types(D).size());

  // DOF permutations are only supported for a single cell type. Reject
  // the request before doing any work so the topology is not modified
  // and no dofmap is built when we are going to throw anyway.
  if (permute_inv && layouts.size() != 1)
  {
    throw std::runtime_error(
        "DOF transformations not yet supported in mixed topology.");
  }

  // Create required mesh entities. Each cell type has its own layout, so
  // a dimension is required if *any* layout places DOFs on it (checking
  // only the first layout misses entities needed by other cell types).
  for (std::int32_t d = 0; d < D; ++d)
  {
    const bool needed = std::ranges::any_of(
        layouts, [d](const ElementDofLayout& layout)
        { return layout.num_entity_dofs(d) > 0; });
    if (needed)
      topology.create_entities(d);
  }

  auto [_index_map, bs, dofmaps]
      = build_dofmap_data(comm, topology, layouts, reorder_fn);
  auto index_map = std::make_shared<common::IndexMap>(std::move(_index_map));

  // If the element's DOF transformations are permutations, permute the
  // DOF numbering on each cell (single cell type guaranteed above)
  if (permute_inv)
  {
    auto c_to_v = topology.connectivity(D, 0);
    assert(c_to_v);
    const std::int32_t num_cells = c_to_v->num_nodes();
    topology.create_entity_permutations();
    const std::vector<std::uint32_t>& cell_info
        = topology.get_cell_permutation_info();

    // Flat storage with `dim` entries per cell; use std::size_t so the
    // offset `cell * dim` cannot overflow 32-bit signed arithmetic
    const std::size_t dim = layouts.front().num_dofs();
    auto& dofmap = dofmaps.front();
    for (std::int32_t cell = 0; cell < num_cells; ++cell)
    {
      std::span<std::int32_t> dofs(dofmap.data() + cell * dim, dim);
      permute_inv(dofs, cell_info[cell]);
    }
  }

  // One DofMap per cell type, all sharing the same index map
  std::vector<DofMap> dms;
  dms.reserve(dofmaps.size());
  for (std::size_t i = 0; i < dofmaps.size(); ++i)
    dms.emplace_back(layouts[i], index_map, bs, std::move(dofmaps[i]), bs);

  return dms;
}
//-----------------------------------------------------------------------------
std::vector<std::string> fem::get_coefficient_names(const ufcx_form& ufcx_form)
{
  // The form exposes the names as a C array of `num_coefficients`
  // C strings; copy each one into an owning std::string.
  auto* const first = ufcx_form.coefficient_name_map;
  auto* const last = first + ufcx_form.num_coefficients;
  std::vector<std::string> names(first, last);
  return names;
}
//-----------------------------------------------------------------------------
std::vector<std::string> fem::get_constant_names(const ufcx_form& ufcx_form)
{
  // The form exposes the names as a C array of `num_constants`
  // C strings; copy each one into an owning std::string.
  auto* const first = ufcx_form.constant_name_map;
  auto* const last = first + ufcx_form.num_constants;
  std::vector<std::string> names(first, last);
  return names;
}
//-----------------------------------------------------------------------------
std::vector<std::int32_t>
fem::compute_integration_domains(fem::IntegralType integral_type,
                                 const mesh::Topology& topology,
                                 std::span<const std::int32_t> entities)
{
  const int tdim = topology.dim();

  // Topological dimension of the entities integrated over
  int dim = -1;
  switch (integral_type)
  {
  case IntegralType::cell:
    dim = tdim;
    break;
  case IntegralType::exterior_facet:
  case IntegralType::interior_facet:
    dim = tdim - 1;
    break;
  case IntegralType::vertex:
    dim = 0;
    break;
  case IntegralType::ridge:
    dim = tdim - 2;
    break;
  default:
    throw std::runtime_error(
        "Cannot compute integration domains. Integral type not supported.");
  }

  // E.g. ridge integrals on a mesh with tdim < 2 have no entity
  // dimension to integrate over. Without this check a negative
  // dimension would be passed to topology.index_map() below.
  if (dim < 0)
  {
    throw std::runtime_error(
        std::format("Cannot compute integration domains. Entity dimension {} "
                    "is invalid for a mesh of topological dimension {}.",
                    dim, tdim));
  }

  {
    // Create span of the owned entities (leaves off any ghosts). Owned
    // entities have local index < size_local(); lower_bound requires
    // `entities` to be sorted.
    assert(topology.index_map(dim));
    auto it1 = std::ranges::lower_bound(entities,
                                        topology.index_map(dim)->size_local());
    entities = entities.first(std::distance(entities.begin(), it1));
  }

  // Return (entity -> cell, cell -> entity) connectivity for entities of
  // dimension `entity_dim`, throwing if either has not been computed
  auto get_connectivities = [tdim, &topology](int entity_dim)
      -> std::pair<std::shared_ptr<const graph::AdjacencyList<int>>,
                   std::shared_ptr<const graph::AdjacencyList<int>>>
  {
    auto e_to_c = topology.connectivity(entity_dim, tdim);
    if (!e_to_c)
    {
      throw std::runtime_error(
          std::format("Topology entity-to-cell connectivity has not been "
                      "computed for entity dim {}.",
                      entity_dim));
    }

    auto c_to_e = topology.connectivity(tdim, entity_dim);
    if (!c_to_e)
    {
      throw std::runtime_error(
          std::format("Topology cell-to-entity connectivity has not been "
                      "computed for entity dim {}.",
                      entity_dim));
    }
    return {e_to_c, c_to_e};
  };

  std::vector<std::int32_t> entity_data;
  switch (integral_type)
  {
  case IntegralType::cell:
  {
    // Cell integrals: the data is the list of owned cells itself
    entity_data.assign(entities.begin(), entities.end());
    break;
  }
  case IntegralType::interior_facet:
  {
    auto [f_to_c, c_to_f] = get_connectivities(tdim - 1);

    // Create indicator for interprocess facets (sized to include ghosts,
    // since interprocess facet indices may refer to ghost facets)
    assert(topology.index_map(tdim - 1));
    const std::vector<std::int32_t>& interprocess_facets
        = topology.interprocess_facets();
    std::vector<std::int8_t> interprocess_marker(
        topology.index_map(tdim - 1)->size_local()
            + topology.index_map(tdim - 1)->num_ghosts(),
        0);
    std::ranges::for_each(interprocess_facets, [&interprocess_marker](auto f)
                          { interprocess_marker[f] = 1; });

    // Each interior facet contributes two (cell, local facet) pairs, i.e.
    // four values; reserve for the case where every facet is interior
    entity_data.reserve(4 * entities.size());
    for (auto f : entities)
    {
      if (f_to_c->num_links(f) == 2)
      {
        // Get the facet as a pair of (cell, local facet) pairs, one
        // for each cell
        auto facets
            = impl::get_cell_facet_pairs<2>(f, f_to_c->links(f), *c_to_f);
        entity_data.insert(entity_data.end(), facets.begin(), facets.end());
      }
      else if (interprocess_marker[f])
      {
        // A facet on a process boundary seen from only one cell: the
        // neighbouring cell lives on another process and is not ghosted
        throw std::runtime_error(
            "Cannot compute interior facet integral over interprocess facet. "
            "Use \"shared facet\"  ghost mode when creating the mesh.");
      }
    }
    break;
  }
  case IntegralType::exterior_facet:
  case IntegralType::vertex:
  case IntegralType::ridge:
  {
    // One (cell, local entity index) pair per entity
    auto [e_to_c, c_to_e] = get_connectivities(dim);
    entity_data.reserve(2 * entities.size());
    for (auto entity : entities)
    {
      std::array<std::int32_t, 2> pair = impl::get_cell_entity_pairs<1>(
          entity, e_to_c->links(entity), *c_to_e);
      entity_data.insert(entity_data.end(), pair.begin(), pair.end());
    }
    break;
  }
  default:
    throw std::runtime_error(
        "Cannot compute integration domains. Integral type not supported.");
  }

  return entity_data;
}
//-----------------------------------------------------------------------------
