package gov.cms.grouper.snf.util;

import gov.cms.grouper.snf.SnfTables;
import gov.cms.grouper.snf.model.Assessment;
import gov.cms.grouper.snf.model.SnfDiagnosisCode;
import gov.cms.grouper.snf.model.enums.Rai300;
import gov.cms.grouper.snf.model.table.BasicRow;
import gov.cms.grouper.snf.model.table.PerformanceRecodeRow;
import gov.cms.grouper.snf.r2.logic.nursing.RaiSets;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class ClaimInfo {

  /**
   * Assessment values that {@link #isComaAndNoActivities()} treats as "dependent or activity did
   * not occur".
   */
  private static final List<Integer> DEPENDENT_OR_NOT_OCCURRED_VALUES =
      Collections.unmodifiableList(Arrays.asList(1, 9, 88));

  /** Number of decimal places kept when averaging an item group in the function score. */
  private static final int AVERAGE_SCALE = 3;

  private final int dataVersion;
  /** Assessments keyed by item name, i.e. the {@link Rai300} constant name. */
  private final Map<String, Assessment> byItem;
  private final boolean hasIpa;
  /** Supplies the function score; wrapped in {@code Lazy}, presumably so it is computed once. */
  private final Supplier<Integer> functionScoreCache;

  /**
   * Creates a claim view over the given assessments.
   *
   * @param dataVersion version passed to version-specific table lookups (see
   *     {@link #performanceRecode(Supplier)})
   * @param hasIpa selects the IPA item sets when true, the non-IPA item sets otherwise
   * @param assessmentMap assessments keyed by item name; stored as-is (not copied)
   */
  public ClaimInfo(int dataVersion, boolean hasIpa, Map<String, Assessment> assessmentMap) {
    this.dataVersion = dataVersion;
    this.hasIpa = hasIpa;
    this.byItem = assessmentMap;

    // The score is not computed here; the supplier defers all work until getFunctionScore().
    this.functionScoreCache = new Lazy<>(() -> calculateFunctionScore(
        (claim, raiString) -> claim.performanceRecode(
            () -> claim.getAssessmentValue(Rai300.valueOf(raiString))),
        bedMobilitySet(), transferSet(),
        Collections.emptySet(), // no walking items contribute to this score
        eatingAndToiletSet()));
  }

  /**
   * Creates a {@code ClaimInfo} from a list of assessments, indexing them by item name. If the
   * list contains several assessments for the same item, the first one is kept.
   *
   * @param version data version
   * @param ipa whether the IPA item sets apply
   * @param assessments assessments to index
   * @return a new {@code ClaimInfo}
   */
  public static ClaimInfo of(int version, boolean ipa, List<Assessment> assessments) {
    Map<String, Assessment> assessmentMap = assessments.stream()
        .collect(Collectors.toMap(Assessment::getItem, assessment -> assessment,
            (first, duplicate) -> first));
    return new ClaimInfo(version, ipa, assessmentMap);
  }

  /**
   * Creates a {@code ClaimInfo} over an already-indexed assessment map.
   *
   * @param version data version
   * @param ipa whether the IPA item sets apply
   * @param assessmentMap assessments keyed by item name; stored as-is (not copied)
   * @return a new {@code ClaimInfo}
   */
  public static ClaimInfo of(int version, boolean ipa, Map<String, Assessment> assessmentMap) {
    return new ClaimInfo(version, ipa, assessmentMap);
  }

  /**
   * Converts RAI items to their constant names, preserving iteration order.
   *
   * @param items items to convert
   * @return the {@link Enum#name()} of each item
   */
  public static List<String> getString(Collection<Rai300> items) {
    return items.stream().map(Enum::name).collect(Collectors.toList());
  }

  /** @return whether the IPA item sets apply to this claim */
  public boolean hasIpa() {
    return hasIpa;
  }

  /** @return the data version supplied at construction */
  public int getDataVersion() {
    return dataVersion;
  }

  /**
   * Returns the integer value of the assessment recorded for {@code field}.
   *
   * @param field item to look up
   * @return the assessment's value, or {@code Assessment.NULL_VALUE} if no assessment is recorded
   */
  public int getAssessmentValue(Rai300 field) {
    Assessment found = byItem.get(field.name());
    return found == null ? Assessment.NULL_VALUE : found.getValueInt();
  }

  /**
   * Returns the assessment recorded for {@code field}.
   *
   * @param field item to look up
   * @return the assessment, or {@code null} if none is recorded
   */
  public Assessment getAssessment(Rai300 field) {
    return byItem.get(field.name());
  }

  /**
   * Checks whether at least one of the given items has a recorded assessment for which
   * {@code Assessment.isCheck()} is true.
   *
   * @param rai300s items to inspect
   * @return true if any item is recorded and checked
   */
  public boolean isAnyAssessmentValuesPresent(Collection<Rai300> rai300s) {
    return rai300s.stream().map(Rai300::name)
        .anyMatch(name -> itemSatisfies(name, Assessment::isCheck));
  }

  /**
   * Checks whether at least one of the given items has a recorded assessment whose value is
   * strictly greater than {@code n}.
   *
   * @param rai300s items to inspect
   * @param n exclusive lower bound
   * @return true if any recorded item's value exceeds {@code n}
   */
  public boolean isAnyAssessmentValuesGreaterThanN(Collection<Rai300> rai300s, int n) {
    return rai300s.stream().map(Rai300::name)
        .anyMatch(name -> itemSatisfies(name, ast -> ast.getValueInt() > n));
  }

  /**
   * Checks whether a single item has a recorded assessment for which
   * {@code Assessment.isCheck()} is true.
   *
   * @param rai300 item to inspect
   * @return true if the item is recorded and checked; false if it is missing or unchecked
   */
  public boolean isCheckedAndNotNull(Rai300 rai300) {
    return isAnyAssessmentValuesPresent(Collections.singletonList(rai300));
  }

  /**
   * Counts the given items that have a recorded assessment for which
   * {@code Assessment.isCheck()} is true.
   *
   * @param rai300s items to inspect
   * @return number of recorded, checked items
   */
  public int countAssessmentPresent(Collection<Rai300> rai300s) {
    return (int) rai300s.stream().map(Rai300::name)
        .filter(name -> itemSatisfies(name, Assessment::isCheck))
        .count();
  }

  /**
   * Collects the distinct NTA categories of the given diagnosis codes.
   *
   * @param codes diagnosis codes
   * @return the set of {@code SnfDiagnosisCode.getNtaCategory()} values
   */
  public static Set<String> getNtaCategories(Collection<SnfDiagnosisCode> codes) {
    return codes.stream().map(SnfDiagnosisCode::getNtaCategory).collect(Collectors.toSet());
  }

  /**
   * Tests the assessment recorded for {@code field} against {@code condition}.
   *
   * @param field item to look up
   * @param condition test applied to the recorded assessment
   * @return false if no assessment is recorded, otherwise the result of {@code condition}
   */
  public boolean hasAssessmentOf(Rai300 field, Predicate<Assessment> condition) {
    return itemSatisfies(field.name(), condition);
  }

  /** Null-safe lookup by item name followed by {@code condition}; false when nothing is recorded. */
  private boolean itemSatisfies(String itemName, Predicate<Assessment> condition) {
    Assessment ast = byItem.get(itemName);
    return ast != null && condition.test(ast);
  }

  /**
   * Determine the resident’s cognitive status based on the staff assessment rather than on resident
   * interview
   * <a href="doc-files/mds-3.0-rai-manual-v1.17.1_october_2019.pdf#page=702" class="req">Step3</a>.
   *
   * <p>Returns true when any of the following holds:
   * <ul>
   *   <li>at least two of B0700 &gt; 0, C0700 == 1, C1000 &gt; 0 hold, and B0700 &gt;= 2 or
   *       C1000 &gt;= 2;</li>
   *   <li>C1000 == 3;</li>
   *   <li>{@code isComaAndNoActivities} returns true.</li>
   * </ul>
   * The three integer suppliers are always invoked; {@code isComaAndNoActivities} is invoked only
   * when the first two conditions are false.
   *
   * @return true if one of the conditions above holds
   */
  public static boolean isClassifiedBehavioralSymptomsCognitivePerformance(
      Supplier<Integer> b0700Supplier, Supplier<Integer> c0700Supplier,
      Supplier<Integer> c1000Supplier, Supplier<Boolean> isComaAndNoActivities) {
    int b0700 = b0700Supplier.get();
    int c0700 = c0700Supplier.get();
    int c1000 = c1000Supplier.get();

    return meetsCombinedCriteria(b0700, c0700, c1000) || c1000 == 3
        || isComaAndNoActivities.get();
  }

  /** At least two of the three indicators are present and at least one of B0700/C1000 is >= 2. */
  private static boolean meetsCombinedCriteria(int b0700, int c0700, int c1000) {
    int count = 0;
    if (b0700 > 0) {
      count++;
    }
    if (c0700 == 1) {
      count++;
    }
    if (c1000 > 0) {
      count++;
    }
    return count >= 2 && (b0700 >= 2 || c1000 >= 2);
  }

  /**
   * Convert admission performance score to function score for PT/OT and Nursing, based on
   * Performance_Recode.csv.
   *
   * @param score supplies the performance score used as the table lookup key
   * @return the matching row's function score for this claim's data version, or 0 if no row
   *     matches
   */
  public int performanceRecode(Supplier<Integer> score) {
    PerformanceRecodeRow row = SnfTables.get(SnfTables.performanceRecodeTable, score.get(),
        BasicRow::isVersion, dataVersion);
    return row == null ? 0 : row.getFunctionScore();
  }

  /** @return the function score computed from this claim's IPA-dependent item sets */
  public int getFunctionScore() {
    return functionScoreCache.get();
  }

  /**
   * Drops items that have no assessment or whose value is {@code Assessment.NULL_VALUE}, then
   * delegates to {@link #calculateFunctionScoreString}.
   *
   * @return total function score
   */
  protected int calculateFunctionScore(BiFunction<ClaimInfo, String, Integer> functionalAssessments,
      Collection<Rai300> bedMobilityList, Collection<Rai300> transferList,
      Collection<Rai300> walkingList, Collection<Rai300> generalItemList) {

    Predicate<Rai300> answered = rai -> {
      Assessment ast = getAssessment(rai);
      return ast != null && ast.getValueInt() != Assessment.NULL_VALUE;
    };

    return calculateFunctionScoreString(functionalAssessments,
        answeredItemNames(bedMobilityList, answered), answeredItemNames(transferList, answered),
        answeredItemNames(walkingList, answered), answeredItemNames(generalItemList, answered));
  }

  /** Names of the items in {@code items} accepted by {@code answered}, in iteration order. */
  private static List<String> answeredItemNames(Collection<Rai300> items,
      Predicate<Rai300> answered) {
    return getString(items.stream().filter(answered).collect(Collectors.toList()));
  }

  /**
   * Calculate function score based on the lists passed in, since different payment components
   * calculate the function score based on different items.
   *
   * <p>Each item is scored with {@code functionalAssessments}. Bed-mobility, transfer and walking
   * sums are divided by the fixed divisors 2, 3 and 2 respectively (regardless of how many items
   * were supplied), rounded HALF_UP to three decimals. General items are added unaveraged. The
   * total is rounded HALF_UP to an integer.
   *
   * @return total function score
   */
  public int calculateFunctionScoreString(
      BiFunction<ClaimInfo, String, Integer> functionalAssessments,
      Collection<String> bedMobilityList, Collection<String> transferList,
      Collection<String> walkingList, Collection<String> generalItemList) {

    BigDecimal avgBedMobility = average(sumAssessments(bedMobilityList, functionalAssessments), 2);
    BigDecimal avgTransfer = average(sumAssessments(transferList, functionalAssessments), 3);
    BigDecimal avgWalking = average(sumAssessments(walkingList, functionalAssessments), 2);
    BigDecimal generalItemSum = sumAssessments(generalItemList, functionalAssessments);

    BigDecimal total = generalItemSum.add(avgBedMobility).add(avgTransfer).add(avgWalking);
    return total.setScale(0, RoundingMode.HALF_UP).intValue();
  }

  /** Divides {@code sum} by {@code divisor}, rounding HALF_UP to {@link #AVERAGE_SCALE} places. */
  private static BigDecimal average(BigDecimal sum, int divisor) {
    return sum.divide(BigDecimal.valueOf(divisor), AVERAGE_SCALE, RoundingMode.HALF_UP);
  }

  /** Sums {@code functionalAssessments} applied to every item name in {@code items}. */
  private BigDecimal sumAssessments(Collection<String> items,
      BiFunction<ClaimInfo, String, Integer> functionalAssessments) {
    return items.stream().map((item) -> functionalAssessments.apply(this, item))
        .map(BigDecimal::valueOf).reduce(BigDecimal.ZERO, BigDecimal::add);
  }

  /**
   * Returns true if B0100 (Coma) equals 1 and every activity item has a value of 1, 9, or 88.
   * The activity items are the union of the bed-mobility, transfer, and eating/toileting sets from
   * {@code RaiSets}, using the IPA variants when {@link #hasIpa()} is true. Per the original
   * documentation, these are (GG0130A1, GG0130C1, GG0170B1, GG0170C1, GG0170D1, GG0170E1,
   * GG0170F1), or the corresponding "5" items for IPA.
   *
   * @return if coma and no activities at all
   */
  public boolean isComaAndNoActivities() {
    if (getAssessmentValue(Rai300.B0100) != 1) {
      return false;
    }

    Set<Rai300> activities = EnumSet.noneOf(Rai300.class);
    activities.addAll(bedMobilitySet());
    activities.addAll(transferSet());
    activities.addAll(eatingAndToiletSet());

    // A missing item yields NULL_VALUE, which is only accepted if it equals one of 1, 9, 88.
    return activities.stream()
        .allMatch(item -> DEPENDENT_OR_NOT_OCCURRED_VALUES.contains(getAssessmentValue(item)));
  }

  /** Bed-mobility items: the IPA set when {@link #hasIpa()} is true, else the non-IPA set. */
  private Set<Rai300> bedMobilitySet() {
    return hasIpa() ? RaiSets.BED_MOBILITY_IPA.getSet() : RaiSets.BED_MOBILITY_NON_IPA.getSet();
  }

  /** Transfer items: the IPA set when {@link #hasIpa()} is true, else the non-IPA set. */
  private Set<Rai300> transferSet() {
    return hasIpa() ? RaiSets.TRANSFER_IPA.getSet() : RaiSets.TRANSFER_NON_IPA.getSet();
  }

  /** Eating and toileting items: the IPA set when {@link #hasIpa()} is true, else non-IPA. */
  private Set<Rai300> eatingAndToiletSet() {
    return hasIpa() ? RaiSets.EATING_AND_TOILET_IPA.getSet()
        : RaiSets.EATING_AND_TOILET_NON_IPA.getSet();
  }

  /**
   * Collects the names of recorded assessments matching {@code checkedAssessment}.
   *
   * @param checkedAssessment filter to apply; {@code null} selects every assessment
   * @return the distinct {@code Assessment.getName()} values of the matching assessments
   */
  public Set<String> getAssessmentNames(Predicate<Assessment> checkedAssessment) {
    Predicate<Assessment> filter = checkedAssessment == null ? (ast) -> true : checkedAssessment;
    return byItem.values().stream().filter(filter).map(Assessment::getName)
        .collect(Collectors.toSet());
  }

  /** @return a new, mutable set containing every recorded assessment */
  public Set<Assessment> getAssessments() {
    return new HashSet<>(byItem.values());
  }

}
