import React, { useEffect, useState, useMemo, useCallback } from 'react';
import axios from 'axios';
import { useAuthInfo } from "@propelauth/react";
import MDBox from 'components/MDBox';
import MDTypography from 'components/MDTypography';
import { Card, CircularProgress } from '@mui/material';
import Plot from 'react-plotly.js';
import { blue } from '@mui/material/colors';
import { MaterialReactTable } from 'material-react-table';
import { createTheme, ThemeProvider } from '@mui/material/styles';
import { useBrand } from 'layouts/utils/BrandContext';

// Style overrides for the table header section.
const tableHeadOverrides = {
  root: {
    backgroundColor: '#e0e0e0',
  },
};

// Style overrides for table rows: header rows get a light grey background,
// every row is white by default and turns grey with a pointer cursor on hover.
const tableRowOverrides = {
  head: {
    backgroundColor: '#F1F1F1',
  },
  root: {
    backgroundColor: '#ffffff',
    '&:hover': {
      backgroundColor: '#cccccc',
      cursor: 'pointer',
    },
  },
};

// Style overrides for table cells: header cells use dark grey text.
const tableCellOverrides = {
  head: {
    color: '#333',
  },
};

// MUI theme carrying the table-specific overrides defined above.
const theme = createTheme({
  components: {
    MuiTableHead: { styleOverrides: tableHeadOverrides },
    MuiTableRow: { styleOverrides: tableRowOverrides },
    MuiTableCell: { styleOverrides: tableCellOverrides },
  },
});

/**
 * Format a number as a dollar amount with exactly two fraction digits,
 * using the runtime's default locale for digit grouping and separators.
 *
 * The minus sign (placed before the `$`) is only emitted when the amount is
 * still non-zero after rounding to two decimals, so tiny negative values
 * render as "$0.00" rather than "-$0.00".
 *
 * @param {number} value - Amount to format.
 * @returns {string} The formatted amount, e.g. `$1,234.50` or `-$12.00` in an en-US runtime.
 */
const formatCurrency = (value) => {
  const magnitude = Math.abs(value);
  const formattedValue = magnitude.toLocaleString(undefined, {
    minimumFractionDigits: 2,
    maximumFractionDigits: 2,
  });
  // Decide the sign from the rounded magnitude, not the raw value.
  const roundsToZero = Number(magnitude.toFixed(2)) === 0;
  return value < 0 && !roundsToZero ? `-$${formattedValue}` : `$${formattedValue}`;
};

const ModelStatisticsTab = React.memo(({ refresh }) => {
    const [errors, setErrors] = useState({});
    const [modelsOutputs, setModelsOutputs] = useState({});
    const { championModel, setChampionModel } = useBrand();
    const [loading, setLoading] = useState(true);
    const auth = useAuthInfo();

    // Fetch model outputs
    const fetchModelsOutputs = useCallback(async (modelsOutputs) => {
        setLoading(true);
        const model_type = championModel?.model_type || '';
        const queryParams = new URLSearchParams({
          models_outputs: JSON.stringify(modelsOutputs),
          champion_model_type: model_type
        }).toString();
        const baseUrl = process.env.REACT_APP_API_BASE_URL;
        const fullUrl = `${baseUrl}/model/display_model_statistics?${queryParams}`;
        try {
          const response = await axios.post(
            fullUrl, {},
            {
              headers: {
                "Authorization": `Bearer ${auth.accessToken}`,
                "Accept": "application/json"
              }
            }
          );
          if (response.status === 200 && response.data.success) {
            setModelsOutputs(response.data.data);
            if (response.data.data['champion_model_data'] && Object.keys(response.data.data['champion_model_data']).length > 0) {
              setChampionModel(response.data.data['champion_model_data']);
            }
          } else {
            const errorDetail = response.data.detail || 'Unknown error';
            setErrors(prevState => ({ ...prevState, general: errorDetail }));
          }
        } catch (error) {
          const errorDetail = error.response ? error.response.data.detail : error.message;
          setErrors(prevState => ({ ...prevState, general: errorDetail }));
        } finally {
          setLoading(false);
        }
    }, [auth.accessToken]);

    useEffect(() => {
        if (refresh && refresh['Models_results']) {
          const modelsOutputs = {};
          refresh['Models_results'].forEach(file => {
            modelsOutputs[file.model_type] = file.file_url;
          });
          fetchModelsOutputs(modelsOutputs);
        }
    }, [refresh, fetchModelsOutputs]);

    // Dynamically determine prediction column structure based on data
    const getPredictionColumns = useCallback(() => {
        if (!modelsOutputs?.models) return [];

        const firstModel = modelsOutputs.models[Object.keys(modelsOutputs.models)[0]];
        const firstPrediction = firstModel?.predictions?.[0] || {};

        return Object.keys(firstPrediction).map((key) => ({
            accessorKey: key,
            header: key.replace(/_/g, ' '), // Format header text
            Cell: ({ cell }) => {
                const value = cell.getValue();
                return typeof value === 'number' ? formatCurrency(value) : value;
            },
        }));
    }, [modelsOutputs]);

    const memoizedPlotData = useMemo(() => {
      if (!modelsOutputs?.models) {
        return [];
      }
      return Object.keys(modelsOutputs.models).map((subModelName) => {
        const firstPrediction = modelsOutputs.models[subModelName].predictions?.[0] || {};

        // Find the key for the actual values dynamically
        const actualKey = Object.keys(firstPrediction).find(key => key !== 'Date' && key !== 'Predicted');

        return {
          subModelName,
          plotData: [
            {
              x: modelsOutputs.models[subModelName].predictions.map(row => row.Date),
              y: modelsOutputs.models[subModelName].predictions.map(row => row[actualKey]), // Dynamically select the correct key for Actual
              mode: 'lines',
              name: 'Actual',
              line: { color: 'blue' },
            },
            {
              x: modelsOutputs.models[subModelName].predictions.map(row => row.Date),
              y: modelsOutputs.models[subModelName].predictions.map(row => row.Predicted),
              mode: 'lines',
              name: 'Predicted',
              line: { color: 'red' },
            },
          ],
          layout: {
            title: `Actual vs Predicted for ${subModelName}`,
            xaxis: { title: 'Date' },
            yaxis: { title: 'Revenue' },
            height: 600,
            width: 1000,
            showlegend: true,
          },
          config: {
            responsive: true,
          },
        };
      });
    }, [modelsOutputs]);

  const { models } = modelsOutputs;

  if (loading && refresh) {
    return (
      <MDBox display="flex" justifyContent="center" alignItems="center" height="100vh">
        <CircularProgress sx={{ color: blue[500] }} />
      </MDBox>
    );
  }

  return (
    <MDBox p={3} pb={3} pr={5}>
      <MDTypography variant="h3">Model Statistics</MDTypography>
      <br />
      <Card>
        <MDBox p={3} pb={3} pr={5}>
          {memoizedPlotData.map(({ subModelName, plotData }) => (
            <React.Fragment key={subModelName}>
              <MDTypography variant="h4">{subModelName} - Actual vs Predicted</MDTypography>
              {models[subModelName].predictions && models[subModelName].predictions.length > 0 && (
                <Card sx={{ mb: 3, p: 3 }}>
                  <Plot
                    data={plotData}
                    layout={plotData.layout} // Use memoized layout
                    config={plotData.config} // Use memoized config
                  />
                </Card>
              )}
              <MDTypography variant="h4">Predictions Table</MDTypography>
              <Card sx={{ mb: 3, p: 3 }}>
                <ThemeProvider theme={theme}>
                  <MaterialReactTable
                    columns={getPredictionColumns()} // Dynamically generated columns
                    data={modelsOutputs.models[subModelName].predictions || []}
                  />
                </ThemeProvider>
              </Card>
            </React.Fragment>
          ))}
        </MDBox>
      </Card>
    </MDBox>
  );
});

export default ModelStatisticsTab;
