import React, { useEffect, useState, useMemo, useCallback } from 'react';
import axios from 'axios';
import { useAuthInfo } from "@propelauth/react";
import MDBox from 'components/MDBox';
import { IconButton, Card, CircularProgress, Tooltip } from "@mui/material";
import MDTypography from 'components/MDTypography';
import { MaterialReactTable } from 'material-react-table';
import pymcIcon from 'assets/images/pymc.png';
import robynIcon from 'assets/images/robyn.png';
import lightweightIcon from 'assets/images/lightweight.png';
import olsIcon from 'assets/images/ols.png';
import DownloadIcon from '@mui/icons-material/Download';
import { saveAs } from 'file-saver';
import { blue } from '@mui/material/colors';
import { createTheme, ThemeProvider } from '@mui/material/styles';
import { useParams } from 'react-router-dom';
import { useBrand } from 'layouts/utils/BrandContext';

// Greys shared by the header-row and hover styles below.
const TABLE_HEADER_ROW_BG = '#F1F1F1';
const TABLE_HOVER_BG = '#cccccc';

/**
 * MUI theme applied (through ThemeProvider) to the model-results table.
 * `MuiTableRow.head` styles header rows; `MuiTableRow.root` styles every row
 * and adds a grey highlight plus pointer cursor on hover.
 */
const theme = createTheme({
  components: {
    MuiTableHead: {
      styleOverrides: {
        root: { backgroundColor: '#e0e0e0' },
      },
    },
    MuiTableRow: {
      styleOverrides: {
        head: {
          backgroundColor: TABLE_HEADER_ROW_BG,
          '&.MuiTableRow-head': { backgroundColor: TABLE_HEADER_ROW_BG },
          '&:hover': { backgroundColor: TABLE_HOVER_BG },
        },
        root: {
          '&.MuiTableRow-root': { backgroundColor: '#ffffff' },
          '&:hover': { backgroundColor: TABLE_HOVER_BG, cursor: 'pointer' },
        },
      },
    },
    MuiTableCell: {
      styleOverrides: {
        head: { color: '#333' },
      },
    },
  },
});

/**
 * Look up the logo image for a model type.
 *
 * Matching is case-insensitive. Unknown, null, or undefined types return null
 * so callers can fall back to plain text.
 *
 * @param {string|null|undefined} modelType - Model name, e.g. 'PyMC' or 'robyn'.
 * @returns {string|null} The imported icon asset, or null when nothing matches.
 */
const getModelIcon = (modelType) => {
  const iconsByType = new Map([
    ['pymc', pymcIcon],
    ['robyn', robynIcon],
    ['lightweight', lightweightIcon],
    ['ols', olsIcon],
  ]);
  const normalizedType = modelType?.toLowerCase();
  // Map (not a plain object) so keys like 'constructor' never hit the prototype.
  return iconsByType.has(normalizedType) ? iconsByType.get(normalizedType) : null;
};

const DisplayModelResults = React.memo(({ data, columns }) => (
  <ThemeProvider theme={theme}>
    <MaterialReactTable
      columns={columns}
      data={data}
      enableExpanding
      enableRowSelection={false}
      getSubRows={(row) => row.subRows || []}
    />
  </ThemeProvider>
));

const ModelsResultsTab = React.memo(({ refresh }) => {
  const [modelsOutputs, setModelsOutputs] = useState({});
  const [errors, setErrors] = useState({});
  const [loading, setLoading] = useState(true);
  const { loadContext, saveContext, selectedBrand, selectedRefresh, championModel, setChampionModel } = useBrand();
  const { brandName, refreshName, tab } = useParams();
  const auth = useAuthInfo();

  useEffect(() => {
    const fetchContext = async () => {
      if (brandName && refreshName && !selectedBrand && !selectedRefresh) {
        await loadContext(brandName, refreshName);
      }
    };
    fetchContext();
  }, [brandName, refreshName, loadContext]);

  // Fetch model outputs from backend
  const fetchModelsOutputs = useCallback(async (modelsOutputs) => {
    setLoading(true);
    const queryParams = new URLSearchParams({ models_outputs: JSON.stringify(modelsOutputs) }).toString();
    const baseUrl = process.env.REACT_APP_API_BASE_URL;
    const fullUrl = `${baseUrl}/model/display_model_results?${queryParams}`;
    try {
      const response = await axios.post(fullUrl, {}, {
        headers: { "Authorization": `Bearer ${auth.accessToken}`, "Accept": "application/json" }
      });
      if (response.status === 200 && response.data.success) {
        setModelsOutputs(response.data.data);
      } else {
        const errorDetail = response.data.detail || 'Unknown error';
        setErrors(prevState => ({ ...prevState, general: errorDetail }));
      }
    } catch (error) {
      const errorDetail = error.response ? error.response.data.detail : error.message;
      setErrors(prevState => ({ ...prevState, general: errorDetail }));
    } finally {
      setLoading(false);
    }
  }, [auth.accessToken]);

  const calculateModelChampion = useCallback(() => {
    if (!modelsOutputs || Object.keys(modelsOutputs).length === 0) return;
    let champion = null;
    let maxAccuracy = -Infinity;
    let bestRoiValue = null;

    Object.keys(modelsOutputs).forEach(modelType => {
      const modelData = modelsOutputs[modelType];
      if (modelData['Testing Accuracy (%)'] > maxAccuracy) {
        maxAccuracy = modelData['Testing Accuracy (%)'];
        champion = modelData['Model'];
        bestRoiValue = modelData['ROI'];
      }
    });
    const roiValue = parseFloat(bestRoiValue);
    //const outputs_file = modelsOutputs[modelType];


    const updatedChampionModel = {
        model_type: champion,
        highest_accuracy: maxAccuracy,
        best_roi: isNaN(roiValue) ? '' : roiValue.toFixed(2),
        slide_deck_url: null
    };
    setChampionModel(updatedChampionModel);

     const contextData = {
        championModel: updatedChampionModel,
      };
      saveContext(contextData);
    }, [modelsOutputs, setChampionModel, saveContext]);

  useEffect(() => {
    if (refresh && refresh['Models_results']) {
      const modelsOutputs = {};
      refresh['Models_results'].forEach(file => {
        modelsOutputs[file.model_type] = file.file_url;
      });
      fetchModelsOutputs(modelsOutputs);
    }
  }, [refresh, fetchModelsOutputs]);

  useEffect(() => {
    calculateModelChampion();
  }, [modelsOutputs, calculateModelChampion]);

 // Columns for the table
  const columns = useMemo(() => [
    {
      accessorKey: 'Model',
      header: 'Model',
      size: 40,
      Cell: ({ row }) => {
        const modelValue = row.original['Model'] || row.original['Sales Channel'];
        const modelIcon = row.original['Model'] ? getModelIcon(row.original.Model) : null;
        return modelIcon ? (
          <div style={{ display: 'flex', alignItems: 'center', fontSize: '16px', fontWeight: 'bold' }}>
            <img src={modelIcon} alt={row.original.Model} style={{ width: '32px', height: '32px', marginRight: '8px' }} />
             {modelValue?.toUpperCase()}
          </div>
        ) : (
          modelValue || ''
        );
      },
    },
    {
      accessorKey: 'Ad Spend',
      header: 'Ad Spend',
      size: 40,
      Cell: ({ cell }) => {
        const adSpendValue = parseFloat(cell.getValue());
        if (isNaN(adSpendValue)) return '';
        const formattedValue = Math.abs(adSpendValue).toLocaleString(undefined, { minimumFractionDigits: 0, maximumFractionDigits: 0 });
        return adSpendValue < 0 ? `-$${formattedValue}` : `$${formattedValue}`;
      },
    },
    {
      accessorKey: 'Contribution',
      header: 'Contribution',
      size: 40,
      Cell: ({ cell }) => {
        const contributionValue = parseFloat(cell.getValue());
        if (isNaN(contributionValue)) return '';
        const formattedValue = Math.abs(contributionValue).toLocaleString(undefined, { minimumFractionDigits: 0, maximumFractionDigits: 0 });
        return contributionValue < 0 ? `-$${formattedValue}` : `$${formattedValue}`;
      },
    },
    {
      accessorKey: 'Marginal Contribution',
      header: 'Marginal Contribution',
      Cell: ({ cell }) => {
        const marginalContributionValue = parseFloat(cell.getValue());
        return isNaN(marginalContributionValue) ? '' : marginalContributionValue.toLocaleString();
      },
    },
    {
      accessorKey: 'ROI',
      header: 'ROI',
      size: 40,
      Cell: ({ cell }) => {
        const roiValue = parseFloat(cell.getValue());
        return isNaN(roiValue) ? '' : roiValue.toFixed(2);
      },
    },
    {
      accessorKey: 'MROI',
      header: 'MROI',
      size: 40,
      Cell: ({ cell }) => {
        const mroiValue = parseFloat(cell.getValue());
        return isNaN(mroiValue) ? '' : mroiValue.toFixed(2);
      },
    },
    {
      accessorKey: 'Training Accuracy (%)',
      header: 'Training Accuracy (%)',
      size: 40,
      Cell: ({ cell }) => {
        const trainingValue = parseFloat(cell.getValue());
        return isNaN(trainingValue) ? '' : trainingValue.toFixed(2);
      },
    },
    {
      accessorKey: 'Testing Accuracy (%)',
      header: 'Testing Accuracy (%)',
      size: 40,
      Cell: ({ cell }) => {
        const testingValue = parseFloat(cell.getValue());
        return isNaN(testingValue) ? '' : testingValue.toFixed(2);
      },
    },
    {
      accessorKey: 'DW',
      header: 'DW',
      size: 40,
      Cell: ({ cell }) => cell.getValue()
    },
  ], []);

  // Download table data as JSON
  const handleDownloadJson = () => {
    const jsonStr = JSON.stringify(modelsOutputs, null, 2);
    const blob = new Blob([jsonStr], { type: 'application/json' });
    saveAs(blob, 'models_results.json');
  };

  if (loading && refresh) {
    return (
      <MDBox display="flex" justifyContent="center" alignItems="center" height="100vh">
        <CircularProgress sx={{ color: blue[500] }} />
      </MDBox>
    );
  }

  return (
    <MDBox p={1} pb={3}>
      <MDTypography variant="h3">
       Model Results
      <Tooltip title="Download model results file">
        <IconButton onClick={(event) => {
          handleDownloadJson(event);
        }}>
          <DownloadIcon />
        </IconButton>
     </Tooltip>
     </MDTypography>
      <MDBox mb={2}>
        {championModel && (
          <>
            <Card sx={{ mt: 2, mb: 3, p: 3, backgroundColor: '#e8f4fd', boxShadow: '0 4px 8px rgba(0, 0, 0, 0.1)' }}>
                <MDTypography variant="h4" align="left" gutterBottom>
                  Selected Model:  {championModel?.model_type ? championModel?.model_type.toUpperCase() : 'N/A'}
                </MDTypography>
                <MDBox display="flex" justifyContent="space-evenly" alignItems="center">
                  <MDBox textAlign="center">
                    <MDTypography variant="h6">Best ROI</MDTypography>
                    <MDTypography variant="h4">{championModel?.best_roi || 'N/A'}x</MDTypography>
                  </MDBox>
                  <MDBox textAlign="center">
                    <MDTypography variant="h6">Highest Accuracy</MDTypography>
                    <MDTypography variant="h4">{championModel?.highest_accuracy || 'N/A'}%</MDTypography>
                  </MDBox>
                </MDBox>
            </Card>
            <Card sx={{ mt: 2, mb: 3, p: 3, boxShadow: '0 4px 8px rgba(0, 0, 0, 0.1)' }}>
                <MDTypography variant="h4" my={2}>Insights</MDTypography>
                <ul style={{ listStyleType: 'disc', paddingLeft: '20px' }}>
                  <li>
                    <MDTypography variant="body2">
                      <strong>Best Model:</strong> The best model to choose is based on <strong>{championModel.model_type}</strong> because it has the most reasonable in-sample/out-of-sample accuracy.
                    </MDTypography>
                  </li>
                  <li>
                    <MDTypography variant="body2">
                      <strong>Business Results:</strong> Best ROI of <strong>{championModel.best_roi}x</strong> with the highest testing accuracy of <strong>{championModel.highest_accuracy}%</strong>.
                    </MDTypography>
                  </li>
                </ul>
                <MDTypography variant="body2" my={2}>
                  Other models are included for comparison.
                </MDTypography>
             </Card>
          </>
        )}
      </MDBox>
      <Card>
        <MDBox p={3}>
          {errors.general ? (
            <p style={{ color: 'red' }}>Error: {typeof errors.general === 'object' ? JSON.stringify(errors.general) : errors.general}</p>
          ) : modelsOutputs ? (
            <DisplayModelResults data={modelsOutputs} columns={columns} />
          ) : (
            <p>No model results are available. Go to <b>Modeling</b> to run a model.</p>
          )}
        </MDBox>
      </Card>
    </MDBox>
  );
});

export default ModelsResultsTab;
