import React, { useEffect, useState } from 'react';
import Chart from 'react-apexcharts';
import {
    DetectionModel,
    DetectionModelTrainingAccuracyStatisticMetric,
    DetectionModelTrainingAccuracyStatisticMetricName,
    DetectionModelTrainingAccuracyStatisticSeries,
    DetectionModelTrainingAccuracyStatisticSeriesGroup,
    ProjectType
} from '../../../models/api';
import { useStores } from '../../../store';
import { observer } from 'mobx-react-lite';
import { ApexOptions } from 'apexcharts';
import DnaLoader from '../../../components/DnaLoader';
import { TrainingAccuracyLegend } from './TrainingAccuracyLegend';

/** Props accepted by the training-accuracy chart component. */
interface Props {
    /** Model whose training statistics are charted; nothing is fetched while this is absent. */
    detectionModel?: DetectionModel;
}

/*
 * Raw metric names as they appear in the fetched statistic, each of which is relabelled
 * to a DetectionModelTrainingAccuracyStatisticMetricName display name before charting.
 * NOTE(review): the identifiers appear to be named after the *target* display metric rather
 * than their own value (e.g. MAP_DETECTION_SLUG holds 'fscore' and is relabelled as Map for
 * object detection; FSCORE_SEGMENTATION_SLUG holds 'mAP *') — confirm against the API.
 */
const TOTAL_LOSS_SLUG = 'Total Loss';
const MAP_DETECTION_SLUG = 'fscore';
const MAP_SEGMENTATION_SLUG = 'map';
const FSCORE_SEGMENTATION_SLUG = 'mAP *';

/** Metrics placed in the "base" legend group; these are the only series visible initially. */
const baseMetrics: DetectionModelTrainingAccuracyStatisticMetricName[] = [
    DetectionModelTrainingAccuracyStatisticMetricName.TotalLoss,
    DetectionModelTrainingAccuracyStatisticMetricName.Accuracy,
    DetectionModelTrainingAccuracyStatisticMetricName.Map,
    DetectionModelTrainingAccuracyStatisticMetricName.Map50,
    DetectionModelTrainingAccuracyStatisticMetricName.MapSegmentation,
    DetectionModelTrainingAccuracyStatisticMetricName.FScoreSegmentation
];

/** Size-bucketed F-score metrics, placed in the "F-score" legend group and hidden initially. */
const fScoreMetrics: DetectionModelTrainingAccuracyStatisticMetricName[] = [
    DetectionModelTrainingAccuracyStatisticMetricName.FScoreLarge,
    DetectionModelTrainingAccuracyStatisticMetricName.FScoreMedium,
    DetectionModelTrainingAccuracyStatisticMetricName.FScoreSmall
];

/**
 * Palette shared by the chart and its legend. Several entries are 8-digit hex
 * (#RRGGBBAA), i.e. they carry an alpha component.
 */
const chartColors: string[] = [
    '#353ed1',
    '#dc3545',
    '#fd7e1457',
    '#ffc10745',
    '#28a74545',
    '#330000',
    '#00FF00',
    '#9C27B0'
];

/**
 * Relabels raw metric names reported by the backend with the display names the chart and
 * legend expect. Renamed metrics are shallow copies, so the input (which comes straight
 * from the store) is never mutated.
 *
 * @param metrics Metrics in the order they should be charted.
 * @param isObjectDetection Apply the object-detection renames.
 * @param isObjectSegmentation Apply the object-segmentation renames (only consulted when
 *     `isObjectDetection` is false).
 * @returns A new array with the same length and order as `metrics`.
 */
const normalizeMetricNames = (
    metrics: DetectionModelTrainingAccuracyStatisticMetric[],
    isObjectDetection: boolean,
    isObjectSegmentation: boolean
): DetectionModelTrainingAccuracyStatisticMetric[] => metrics.map((metric) => {
    let name: string = metric.name;
    if (name === TOTAL_LOSS_SLUG) {
        name = DetectionModelTrainingAccuracyStatisticMetricName.TotalLoss;
    }
    if (isObjectDetection) {
        if (name === MAP_DETECTION_SLUG) {
            name = DetectionModelTrainingAccuracyStatisticMetricName.Map;
        }
    } else if (isObjectSegmentation) {
        if (name === MAP_SEGMENTATION_SLUG) {
            name = DetectionModelTrainingAccuracyStatisticMetricName.MapSegmentation;
        } else if (name === FSCORE_SEGMENTATION_SLUG) {
            name = DetectionModelTrainingAccuracyStatisticMetricName.FScoreSegmentation;
        }
    }
    return name === metric.name ? metric : { ...metric, name };
});

/** Output of `groupSeries`. */
interface GroupedSeries {
    /** Every input metric in input order; class-map metrics carry their prefixed display name. */
    series: DetectionModelTrainingAccuracyStatisticMetric[];
    /** Legend groups keyed by series-group name; groups with no members are omitted. */
    groups: Record<string, DetectionModelTrainingAccuracyStatisticSeries[]>;
    /** Initial visibility keyed by display name: base metrics on, everything else off. */
    toggleStatuses: Record<string, boolean>;
}

/**
 * Sorts metrics into the legend groups (base, F-score, per-class mAP) and computes each
 * series' initial visibility. A metric that is neither a base nor an F-score metric is
 * treated as a per-class mAP curve and renamed `"<ClassMap group> (<name>)"`.
 * Never mutates `metrics` or its elements.
 *
 * @param metrics Metrics with display names already normalized.
 * @returns The (possibly renamed) series, the non-empty legend groups and the toggle map.
 */
const groupSeries = (metrics: DetectionModelTrainingAccuracyStatisticMetric[]): GroupedSeries => {
    const groups = Object
        .values(DetectionModelTrainingAccuracyStatisticSeriesGroup)
        .reduce<Record<string, DetectionModelTrainingAccuracyStatisticSeries[]>>((acc, group) => {
            acc[group] = [];
            return acc;
        }, {});
    const toggleStatuses: Record<string, boolean> = {};
    const series: DetectionModelTrainingAccuracyStatisticMetric[] = [];

    metrics.forEach((metric, index) => {
        const metricName = metric.name as DetectionModelTrainingAccuracyStatisticMetricName;
        if (baseMetrics.includes(metricName)) {
            groups[DetectionModelTrainingAccuracyStatisticSeriesGroup.Base].push({ ...metric, index });
            toggleStatuses[metric.name] = true;
            series.push(metric);
        } else if (fScoreMetrics.includes(metricName)) {
            groups[DetectionModelTrainingAccuracyStatisticSeriesGroup.FScore].push({ ...metric, index });
            toggleStatuses[metric.name] = false;
            series.push(metric);
        } else {
            const classMapName = `${DetectionModelTrainingAccuracyStatisticSeriesGroup.ClassMap} (${metric.name})`;
            // Copy instead of assigning to metric.name: the originals belong to the store, and
            // mutating them would re-prefix the name every time this runs over the same data.
            const renamed = { ...metric, name: classMapName };
            groups[DetectionModelTrainingAccuracyStatisticSeriesGroup.ClassMap].push({ ...renamed, index });
            toggleStatuses[classMapName] = false;
            series.push(renamed);
        }
    });

    // Drop groups with no members so they are not handed to the legend.
    Object.entries(groups).forEach(([key, members]) => {
        if (!members.length) {
            delete groups[key];
        }
    });

    return { series, groups, toggleStatuses };
};

/**
 * Line chart of a detection model's training metrics per epoch, with a custom legend.
 * Fetches the statistic through `detectionModelsStore` whenever `detectionModel` changes and
 * rebuilds the chart whenever the fetched data, the model id or the project type changes.
 */
const TrainingAccuracyStatistic: React.FC<Props> = observer(({ detectionModel }) => {
    const { detectionModelsStore, projectsStore } = useStores();
    // Keeps a loader over the chart until ApexCharts fires its `mounted` event.
    const [chartIsLoading, setChartIsLoading] = useState<boolean>(true);
    const [seriesGroups, setSeriesGroups] = useState<
        Record<string, DetectionModelTrainingAccuracyStatisticSeries[]>
    >({});
    const [seriesToggleStatuses, setSeriesToggleStatuses] = useState<
        Record<string, boolean>
    >({});
    const [chartData, setChartData] = useState<{
        series: DetectionModelTrainingAccuracyStatisticMetric[],
        options: ApexOptions
    }>();

    const projectType = detectionModel?.project?.type ?? projectsStore.current?.type;
    const isObjectDetection = projectType === ProjectType.ObjectDetection;
    const isObjectSegmentation = projectType === ProjectType.ObjectSegmentation;
    const chartId = detectionModel?.id;

    useEffect(() => {
        if (!detectionModel) {
            return;
        }
        // Fire-and-forget: loading/result state is read back from the store.
        void detectionModelsStore.fetchItemTrainingAccuracyStatistic(detectionModel.id);
    }, [detectionModel, detectionModelsStore]);

    useEffect(() => {
        const statistic = detectionModelsStore.itemTrainingAccuracyStatistic.data;
        if (!statistic) {
            return;
        }

        const { series, groups, toggleStatuses } = groupSeries(
            normalizeMetricNames(Object.values(statistic.metricsByName), isObjectDetection, isObjectSegmentation)
        );

        // One hidden 0..1 axis per class-map series so each curve keeps its own scale.
        const classMapYAxes = (groups[DetectionModelTrainingAccuracyStatisticSeriesGroup.ClassMap] ?? []).map(
            (item): ApexYAxis => ({
                show: false,
                max: 1,
                seriesName: item.name
            })
        );

        setSeriesGroups(groups);
        setSeriesToggleStatuses(toggleStatuses);

        setChartData({
            series,
            options: {
                chart: {
                    id: chartId,
                    type: 'line',
                    toolbar: {
                        offsetY: -10,
                    },
                    zoom: {
                        enabled: true
                    },
                    animations: {
                        enabled: false
                    },
                    events: {
                        mounted: (chart) => {
                            // Hide every non-base series so the chart matches the legend's
                            // initial toggle state (base metrics on, everything else off).
                            series.forEach((item) => {
                                if (!baseMetrics.includes(item.name as DetectionModelTrainingAccuracyStatisticMetricName)) {
                                    chart.toggleSeries(item.name);
                                }
                            });
                            setChartIsLoading(false);
                        }
                    }
                },
                dataLabels: {
                    enabled: false
                },
                colors: chartColors,
                stroke: {
                    curve: 'straight',
                    width: 2
                },
                xaxis: {
                    decimalsInFloat: 0,
                    type: 'numeric',
                    title: {
                        text: 'Epochs',
                        offsetY: -10
                    }
                },
                yaxis: [
                    {
                        max: isObjectDetection ? undefined : 5,
                        seriesName: DetectionModelTrainingAccuracyStatisticMetricName.TotalLoss,
                        axisBorder: {
                            show: true
                        },
                        title: {
                            text: DetectionModelTrainingAccuracyStatisticMetricName.TotalLoss
                        },
                    },
                    {
                        max: 1,
                        seriesName: DetectionModelTrainingAccuracyStatisticMetricName.Map,
                        axisBorder: {
                            show: true
                        },
                        title: {
                            text: isObjectDetection
                                ? DetectionModelTrainingAccuracyStatisticMetricName.MapOrFScore
                                : DetectionModelTrainingAccuracyStatisticMetricName.Accuracy
                        },
                        opposite: true,
                        showAlways: true
                    },
                    {
                        show: false,
                        max: 1,
                        seriesName: DetectionModelTrainingAccuracyStatisticMetricName.Map50,
                    },
                    {
                        show: false,
                        max: 1,
                        seriesName: DetectionModelTrainingAccuracyStatisticMetricName.FScoreSmall
                    },
                    {
                        show: false,
                        max: 1,
                        seriesName: DetectionModelTrainingAccuracyStatisticMetricName.FScoreMedium
                    },
                    {
                        show: false,
                        max: 1,
                        seriesName: DetectionModelTrainingAccuracyStatisticMetricName.FScoreLarge
                    },
                    ...classMapYAxes
                ],
                legend: {
                    show: false
                }
            }
        });
    }, [detectionModelsStore.itemTrainingAccuracyStatistic.data, chartId, isObjectDetection, isObjectSegmentation]);

    return (
        <>
            {detectionModelsStore.itemTrainingAccuracyStatistic.isLoading ? (
                <DnaLoader />
            ) : chartId && chartData ? (
                <>
                    {chartIsLoading && <DnaLoader />}
                    <Chart
                        type='line'
                        width='100%'
                        height={isObjectDetection ? '77.5%' : '85%'}
                        options={chartData.options}
                        series={chartData.series}
                    />
                    <TrainingAccuracyLegend
                        chartId={chartId}
                        chartColors={chartColors}
                        seriesGroups={seriesGroups}
                        seriesToggleStatuses={seriesToggleStatuses}
                        setSeriesToggleStatuses={setSeriesToggleStatuses}
                        isObjectDetection={isObjectDetection}
                    />
                </>
            ) : (
                <div className='d-flex justify-content-center align-items-center he-100'>
                    Nothing here yet!
                </div>
            )}
        </>
    );
});

export default TrainingAccuracyStatistic;
