package gov.cms.grouper.snf.r2.logic;

import gov.cms.grouper.snf.SnfContext;
import gov.cms.grouper.snf.SnfTables;
import gov.cms.grouper.snf.model.SnfDiagnosisCode;
import gov.cms.grouper.snf.model.SnfProcessException;
import gov.cms.grouper.snf.model.enums.Rai300;
import gov.cms.grouper.snf.model.table.NtaCmgRow;
import gov.cms.grouper.snf.model.table.NtaComorbidityRow;
import gov.cms.grouper.snf.util.ClaimInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=686" class="req">PDPM
 * Payment Component: NTA</a>
 *
 * <p>Computes the NTA score of a claim from its secondary diagnoses and assessment items (step 2)
 * and maps that score to an NTA case-mix group (step 3).
 */
public class NtaLogic extends SnfDataVersionImpl<String> {

  private static final Logger log = LoggerFactory.getLogger(NtaLogic.class);
  // Function that accepts a condition and returns a predicate that tests a data row against it.
  public static final Function<String, Predicate<NtaComorbidityRow>> FeedingFilter =
      (condition) -> (dataRow) -> condition.equalsIgnoreCase(dataRow.getConditionService());
  public static final String ICD10_CODE_HIV = "B20";
  /** NTA categories of the secondary diagnoses, excluding diagnosis code B20. */
  private final Set<String> secondaryDxNtaCategories;
  /** Checked assessment item names (except M0300D1), plus M0300D1 when its value is above 0. */
  private final Set<String> assessmentNames;
  private final Supplier<Boolean> hasK0510A2;
  private final Supplier<Integer> k0710A2Value;
  private final Supplier<Integer> k0710B2Value;
  /** Step 1.2 predicate selecting the applicable parenteral/IV feeding row, if any. */
  private final Predicate<NtaComorbidityRow> parenteralIvFeeding;
  /** Step 1.3 predicate selecting rows matched by the resident's assessment items. */
  private final Predicate<NtaComorbidityRow> step1$3AdditionalComorbidities;
  private final ClaimInfo claim;

  /**
   * Prepares the inputs of the NTA steps from a claim.
   *
   * @param claim the claim whose assessment items and data version are used
   * @param secondaryDiagnosis the secondary diagnoses; B20 codes are excluded from the NTA
   *     categories
   */
  public NtaLogic(ClaimInfo claim, List<SnfDiagnosisCode> secondaryDiagnosis) {
    super(claim.getDataVersion());

    this.claim = claim;
    this.assessmentNames =
        claim.getAssessmentNames((item) -> item.isCheck() && !item.getItem().equals("M0300D1"));
    // note on page 688: M0300D1 counts when its value is greater than zero, not when "checked"
    this.assessmentNames.addAll(claim
        .getAssessmentNames((item) -> item.getValueInt() > 0 && item.getItem().equals("M0300D1")));

    // Constant-first equals so a diagnosis without a value does not throw.
    this.secondaryDxNtaCategories = ClaimInfo.getNtaCategories(secondaryDiagnosis.stream()
        .filter(snfDiagnosisCode -> !ICD10_CODE_HIV.equals(snfDiagnosisCode.getValue()))
        .collect(Collectors.toList()));

    // Suppliers defer reading the assessment values until step 1.2 evaluates them.
    this.hasK0510A2 = () -> claim.isCheckedAndNotNull(Rai300.K0510A2);
    this.k0710A2Value = () -> claim.getAssessmentValue(Rai300.K0710A2);
    this.k0710B2Value = () -> claim.getAssessmentValue(Rai300.K0710B2);
    this.parenteralIvFeeding =
        step1$2ParenteralIvFeedingCondition(this.hasK0510A2, this.k0710A2Value, this.k0710B2Value);
    this.step1$3AdditionalComorbidities =
        (row) -> step1$3AdditionalComorbidities(this.assessmentNames, row);
  }

  /**
   * Determine whether the resident meets the criteria for the comorbidity "Parenteral/IV Feeding –
   * High Intensity" or the comorbidity "Parenteral/IV Feeding – Low Intensity" (step 1.2).
   *
   * <p>High intensity applies when K0510A2 is checked and K0710A2 is 3; low intensity applies when
   * K0510A2 is checked, K0710A2 is 2 and K0710B2 is 2.
   *
   * <p>TODO should update to override to reference the new MDS item
   *
   * @param hasK0510A2 supplies whether K0510A2 is checked and not null
   * @param k0710A2Value supplies the value of K0710A2
   * @param k0710B2Value supplies the value of K0710B2 (only read for the low-intensity check)
   * @return a predicate matching NTA comorbidity rows whose condition/service equals, ignoring
   *     case, the applicable feeding level; a predicate matching nothing when neither level applies
   */
  public Predicate<NtaComorbidityRow> step1$2ParenteralIvFeedingCondition(
      Supplier<Boolean> hasK0510A2, Supplier<Integer> k0710A2Value,
      Supplier<Integer> k0710B2Value) {
    ParenteralIvFeeding feeding = null;
    // Did resident receive parenteral/IV feeding during the last 7 days?
    if (hasK0510A2.get()) {
      int k0710A2 = k0710A2Value.get();
      if (k0710A2 == 3) { // Parenteral/IV Feeding – High Intensity
        feeding = ParenteralIvFeeding.HighIntensity;
      } else if (k0710A2 == 2 && k0710B2Value.get() == 2) { // Parenteral/IV Feeding – Low Intensity
        feeding = ParenteralIvFeeding.LowIntensity;
      }
    }

    Predicate<NtaComorbidityRow> result = (row) -> Boolean.FALSE;
    if (feeding != null) {
      // Accepts the condition and returns a predicate to check on the data row.
      result = NtaLogic.FeedingFilter.apply(feeding.getConditionService());
    }

    return SnfContext.trace(feeding, result);
  }

  /**
   * Determine whether the resident has any additional NTA-related comorbidities. <a
   * href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=686" class= "req">step1.3</a>
   *
   * <p>The parenteral/IV feeding Low/High intensity rows never match here because they are already
   * evaluated in step 1.2.
   *
   * @param assessmentNames assessment item names of the resident; compared upper-cased
   * @param row the NTA comorbidity row to test
   * @return {@code true} if the row's MDS items contain at least one of the assessment names
   */
  public boolean step1$3AdditionalComorbidities(Iterable<String> assessmentNames,
      NtaComorbidityRow row) {
    // Do not check again on the Low/High intensity entries (already checked in step 1.2).
    String conditionService = row.getConditionService();
    if (ParenteralIvFeeding.HighIntensity.getConditionService().equalsIgnoreCase(conditionService)
        || ParenteralIvFeeding.LowIntensity.getConditionService()
            .equalsIgnoreCase(conditionService)) {
      return false;
    }
    for (String name : assessmentNames) {
      // Locale.ROOT keeps item codes stable regardless of the JVM default locale.
      if (row.getMdsItems().contains(name.toUpperCase(Locale.ROOT))) {
        SnfContext.trace(name);
        return true;
      }
    }
    return false;
  }

  /**
   * Summarize the resident’s total NTA score from previous steps
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=688" class="req">step
   * 2</a>
   *
   * <p>Sums the points of every comorbidity row of this data version that matches a secondary
   * diagnosis NTA category, the step 1.2 feeding predicate, or the step 1.3 predicate.
   *
   * @param secondaryDxNtaCategories NTA categories of the secondary diagnoses
   * @param step1$3AdditionalComorbidities step 1.3 predicate
   * @param parenteralIvFeedingFilter step 1.2 predicate
   * @return Total NTA score
   */
  public int step2NtaScore(Collection<String> secondaryDxNtaCategories,
      Predicate<NtaComorbidityRow> step1$3AdditionalComorbidities,
      Predicate<NtaComorbidityRow> parenteralIvFeedingFilter) {

    Predicate<NtaComorbidityRow> conditionService =
        (row) -> secondaryDxNtaCategories.contains(row.getConditionService());

    Predicate<NtaComorbidityRow> conditions = conditionService
        .or(parenteralIvFeedingFilter)  // Step 1.2
        .or(step1$3AdditionalComorbidities); // Step 1.3

    // Only rows belonging to this logic's data version are scored.
    conditions = conditions.and((row) -> row.isVersion(super.getDataVersion()));

    List<Integer> scores = SnfTables.selectAll(SnfTables.ntaComorbidityTableByConditionOfService,
        conditions, (row) -> {
          SnfContext.trace(row.toString());
          return row.getPoint();
        });
    Integer sum = scores.stream().mapToInt(Integer::intValue).sum();

    return SnfContext.trace(sum);
  }

  /**
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=689" class="req">implement
   * step 3</a>
   *
   * <p>Looks up the first NTA CMG row of this data version whose score range accepts
   * {@code ntaScore}.
   *
   * @param ntaScore the total NTA score from step 2
   * @return NTA Case-Mix Group
   * @throws SnfProcessException if no CMG row matches the data version and score
   */
  public String step3NtaCaseMixGroup(int ntaScore) {
    NtaCmgRow row = SnfTables.ntaCmgTable.values().stream().flatMap(Collection::stream)
        .filter((item) -> item.isVersion(getDataVersion()) && item.isScore(ntaScore))
        .findFirst()
        // Fail explicitly instead of dereferencing a missing row and catching the NPE.
        .orElseThrow(() -> new SnfProcessException("Unable to determine CMG for NTA logic",
            new IllegalStateException("No NTA CMG row for data version " + getDataVersion()
                + " and NTA score " + ntaScore)));
    String cmg = row.getCmg();

    return SnfContext.trace(cmg);
  }

  /**
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=686" class="req">Implement
   * PDPM Payment Component: NTA</a>
   *
   * @return the NTA case-mix group of the claim
   */
  @Override
  public String exec() {
    int ntaScore = this.step2NtaScore(this.secondaryDxNtaCategories, step1$3AdditionalComorbidities,
        parenteralIvFeeding);

    String result = this.step3NtaCaseMixGroup(ntaScore);

    return SnfContext.trace("----------- NTA CMG", result);
  }

  /** Returns the claim this logic was built from. */
  public ClaimInfo getClaim() {
    return claim;
  }

  /**
   * Parenteral/IV feeding levels: the MDS items involved and the condition/service name of the
   * matching NTA comorbidity row. The {@code _V220} variants reference K0520A3 instead of K0510A2.
   */
  public enum ParenteralIvFeeding {
    HighIntensity(Arrays.asList(Rai300.K0510A2.name(), Rai300.K0710A2.name()), "Parenteral IV Feeding: Level High"),
    HighIntensity_V220(Arrays.asList(Rai300.K0520A3.name(), Rai300.K0710A2.name()), "Parenteral IV Feeding: Level High"),
    LowIntensity(
        Arrays.asList(Rai300.K0510A2.name(), Rai300.K0710A2.name(), Rai300.K0710B2.name()), "Parenteral IV Feeding: Level Low"),
    LowIntensity_V220(
        Arrays.asList(Rai300.K0520A3.name(), Rai300.K0710A2.name(), Rai300.K0710B2.name()), "Parenteral IV Feeding: Level Low");

    private final SortedSet<String> items;
    private final String conditionService;

    ParenteralIvFeeding(List<String> items, String conditionService) {
      // Sorted, unmodifiable copy so the enum constant's item set cannot be altered by callers.
      this.items = Collections.unmodifiableSortedSet(new TreeSet<>(items));
      this.conditionService = conditionService;
    }

    /** Returns the sorted, unmodifiable MDS item names for this feeding level. */
    public SortedSet<String> getItems() {
      return items;
    }

    /** Returns the condition/service name of the matching NTA comorbidity row. */
    public String getConditionService() {
      return conditionService;
    }
  }

}
