import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from "react"
import { message } from "antd"
import { ScatterPlot as ScatterPlotDataTypes } from "src/components/AIEngine/DataSummary/types"
import useTranslate from "src/utils/useTranslate"
import { SummaryWrapper } from "../summary-wrapper"
import Highcharts from "highcharts"
import HighchartsReact from "highcharts-react-official"
import { useSelector } from "react-redux"
import { StoreState } from "src/store/configureStore"
import { PropertySelectionForm } from "./form"
import { DisplayNames } from "src/typings"
import { convertToPrecision } from "src/utils/decorator"
import { PlotType } from "src/components/Analytics/enums"
import { useValue } from "src/utils/useValue"

/** Props accepted by the ScatterPlot component. */
interface P {
	/** Scatter-plot payload; its `data` field is indexed by property key when building points. */
	scatterPlotData: ScatterPlotDataTypes
	/** Display names forwarded to both axis selection forms. */
	names: DisplayNames
	/** Maps a property key to its label (forwarded to the forms and used for axis titles). */
	getDisplayName: (value: string) => string
	/** Parent state setter; invoked with an updater that adds or removes the scatter-plot selection. */
	setSaveGraphsPayload: any
	/** Previously saved graphs; an entry under `data.plot_data` restores the x1/y1/x2/y2 selection. */
	selectedPlotsData: { [key: string]: any }
}

/**
 * One scatter point: the experiment it belongs to plus its x/y values
 * (as produced by `convertToPrecision`).
 */
interface PlotData {
	/** Experiment key; also used to look up the display id shown in the tooltip. */
	exp: string
	/** Value of the x-axis property for this experiment. */
	x: number
	/** Value of the y-axis property for this experiment. */
	y: number
}
/**
 * True when both coordinates of `point` are finite numbers, i.e. the point can
 * actually be drawn (missing / null values are excluded).
 */
const isPlottable = ({ x, y }: PlotData): boolean => Number.isFinite(x) && Number.isFinite(y)

/**
 * Pairs the x and y property values of every experiment listed under the
 * x property. The y value is looked up by the same experiment key.
 *
 * @param xValues per-experiment values of the property on the x axis
 * @param yValues per-experiment values of the property on the y axis
 * @returns one point per experiment key of `xValues`, in key order
 */
const buildPlotData = (xValues: Record<string, any>, yValues: Record<string, any>): PlotData[] =>
	Object.entries(xValues).map(([exp, xValue]) => ({
		exp,
		x: convertToPrecision(xValue),
		y: convertToPrecision(yValues[exp]),
	}))

/**
 * Ordinary least-squares fit through the plottable points of `data`.
 *
 * @returns the two end points `[[minX, fit(minX)], [maxX, fit(maxX)]]` of the
 * fitted line, or `[]` when no line can be fitted: fewer than two plottable
 * points, or all points share the same x (undefined slope). Non-plottable
 * points are skipped instead of being summed as 0.
 */
const computeTrendLine = (data: PlotData[]): [number, number][] => {
	let n = 0
	let sumX = 0
	let sumY = 0
	let sumXY = 0
	let sumX2 = 0
	let minX = Infinity
	let maxX = -Infinity

	for (const point of data) {
		if (!isPlottable(point)) continue
		const { x, y } = point
		n += 1
		sumX += x
		sumY += y
		sumXY += x * y
		sumX2 += x * x
		// Tracked here rather than with Math.min(...xs) so large datasets are
		// never spread into a call's argument list.
		if (x < minX) minX = x
		if (x > maxX) maxX = x
	}

	if (n < 2 || minX === maxX) return []

	const slope = (n * sumXY - sumX * sumY) / (n * sumX2 - sumX ** 2)
	const intercept = (sumY - slope * sumX) / n
	// Guard against a numerically degenerate denominator.
	if (!Number.isFinite(slope) || !Number.isFinite(intercept)) return []

	return [
		[minX, minX * slope + intercept],
		[maxX, maxX * slope + intercept],
	]
}

/**
 * Derives the scatter points of one chart from its current axis selection.
 *
 * Returns `[]` until both axes are selected, so clearing an axis never leaves
 * the previous selection's points on screen. Calls `onNoData` when both axes
 * are selected but no experiment has a plottable (x, y) pair.
 */
const useScatterPoints = (
	data: Record<string, Record<string, any>> | null | undefined,
	xKey: string | null,
	yKey: string | null,
	onNoData: () => void,
): PlotData[] => {
	const points = useMemo(
		() => (xKey && yKey ? buildPlotData(data?.[xKey] ?? {}, data?.[yKey] ?? {}) : []),
		[data, xKey, yKey],
	)

	useEffect(() => {
		if (xKey && yKey && !points.some(isPlottable)) onNoData()
	}, [points, xKey, yKey, onNoData])

	return points
}

/**
 * Two side-by-side scatter plots, each with its own x/y property selectors and
 * a least-squares "Best Fit Line". Restores selections from
 * `selectedPlotsData` and mirrors the current selection into the parent via
 * `setSaveGraphsPayload`.
 */
export const ScatterPlot = memo(({ scatterPlotData, names, getDisplayName, setSaveGraphsPayload, selectedPlotsData }: P) => {
	const [t] = useTranslate()
	const [selectedX1, setSelectedX1] = useState<string | null>(null)
	const [selectedY1, setSelectedY1] = useState<string | null>(null)
	const [selectedX2, setSelectedX2] = useState<string | null>(null)
	const [selectedY2, setSelectedY2] = useState<string | null>(null)
	const { data: numericalSummary } = useSelector((state: StoreState) => state.dataSummary)
	const { getValue } = useValue()

	const warnNoData = useCallback(() => {
		message.error(t("dataSummary.scatterPlot.noDataFound"))
	}, [t])

	const plotData1 = useScatterPoints(scatterPlotData?.data, selectedX1, selectedY1, warnNoData)
	const plotData2 = useScatterPoints(scatterPlotData?.data, selectedX2, selectedY2, warnNoData)

	// Restore axis selections from a previously saved graph, keeping only the
	// keys that still exist in the current scatter-plot data.
	useEffect(() => {
		const savedPlots = selectedPlotsData?.data?.plot_data ?? {}
		if (!Object.keys(savedPlots).includes(PlotType.SCATTER_PLOT)) return

		const saved = savedPlots[PlotType.SCATTER_PLOT] ?? {}
		const available = scatterPlotData?.data
		const restore = (key: any): string | null => (key && available?.[key] ? key : null)

		setSelectedX1(restore(saved.x1))
		setSelectedX2(restore(saved.x2))
		setSelectedY1(restore(saved.y1))
		setSelectedY2(restore(saved.y2))
	}, [scatterPlotData?.data, selectedPlotsData?.data?.plot_data])

	const formulationDisplayIds = numericalSummary?.scatter_plot?.formulation_display_ids

	/** Highcharts options shared by both charts; only axes, points and tooltip header differ. */
	const buildChartOptions = useCallback(
		(xKey: string | null, yKey: string | null, points: PlotData[], headerFormat: string) => {
			const formatAxisLabel = function (this: any) {
				return getValue(this.value)
			}

			return {
				title: { text: "" },
				xAxis: {
					labels: { formatter: formatAxisLabel },
					title: { text: xKey ? getDisplayName(xKey) : "X Axis" },
				},
				yAxis: {
					labels: { formatter: formatAxisLabel },
					title: { text: yKey ? getDisplayName(yKey) : "Y Axis" },
				},
				credits: {
					enabled: false,
				},
				tooltip: {
					useHTML: true,
					headerFormat,
					pointFormatter: function (this: any) {
						return `${formulationDisplayIds?.[this?.exp] ?? ""}<br><br> x:${getValue(this.x)}<br> y: ${getValue(this.y)}`
					},
				},
				series: [
					{
						name: "Experiments",
						data: points,
						type: "scatter",
					},
					{
						type: "line",
						name: "Best Fit Line",
						data: computeTrendLine(points),
						marker: {
							enabled: false,
						},
						states: {
							hover: {
								lineWidth: 0,
							},
						},
						enableMouseTracking: false,
						dashStyle: "dash",
						lineWidth: 6,
						color: "#FF0000",
					},
				],
			}
		},
		[getDisplayName, getValue, formulationDisplayIds],
	)

	const highchartsScatterPlotOption1 = useMemo(
		() => buildChartOptions(selectedX1, selectedY1, plotData1, "<small>{series.name}</small><br/>"),
		[buildChartOptions, selectedX1, selectedY1, plotData1],
	)

	const highchartsScatterPlotOption2 = useMemo(
		() => buildChartOptions(selectedX2, selectedY2, plotData2, "<small><b>{series.name}</b></small><br/>"),
		[buildChartOptions, selectedX2, selectedY2, plotData2],
	)

	// Mirror the selection into the parent's save payload: store it while at
	// least one chart has both axes, otherwise drop the scatter-plot entry.
	useEffect(() => {
		const hasCompletePlot = (!!selectedY1 && !!selectedX1) || (!!selectedY2 && !!selectedX2)
		if (hasCompletePlot) {
			setSaveGraphsPayload((prev: any) => ({
				...prev,
				[PlotType.SCATTER_PLOT]: {
					x1: selectedX1,
					y1: selectedY1,
					x2: selectedX2,
					y2: selectedY2,
				},
			}))
			return
		}

		setSaveGraphsPayload((prev: any) => {
			// Returning the same reference lets React skip the parent update.
			if (!prev?.[PlotType.SCATTER_PLOT]) return prev
			const next = { ...prev }
			delete next[PlotType.SCATTER_PLOT]
			return next
		})
	}, [selectedX1, selectedX2, selectedY1, selectedY2, setSaveGraphsPayload])

	return (
		<SummaryWrapper heading={t("dataSummary.scatterPlot")} tooltip={t("aiEngine.tab.dataSummary.scatterplot")} id="scatter-plot" key="scatter-plot">
			<div
				style={{
					width: "100%",
					display: "grid",
					gridTemplateColumns: "repeat(2, minmax(0,1fr))",
					gap: 2,
				}}
			>
				<div>
					<PropertySelectionForm
						key={"plot_1_form"}
						data={scatterPlotData.data}
						names={names}
						value1={selectedX1}
						setValue1={setSelectedX1}
						value2={selectedY1}
						setValue2={setSelectedY1}
						getDisplayName={getDisplayName}
					/>

					{(!!selectedY1 || !!selectedX1) && (
						<ScatterPlotChart options={highchartsScatterPlotOption1} key={"plot_1"} />
					)}
				</div>

				<div>
					<PropertySelectionForm
						key={"plot_2_form"}
						data={scatterPlotData.data}
						names={names}
						value1={selectedX2}
						setValue1={setSelectedX2}
						value2={selectedY2}
						setValue2={setSelectedY2}
						getDisplayName={getDisplayName}
					/>

					{(!!selectedY2 || !!selectedX2) && (
						<ScatterPlotChart options={highchartsScatterPlotOption2} key={"plot_2"} />
					)}
				</div>
			</div>
		</SummaryWrapper>
	)
})

/**
 * Thin wrapper around `HighchartsReact` that renders one chart from prebuilt
 * `options` and refits it to its container whenever the sidebar is toggled.
 *
 * `chart.redraw()` only re-renders after option/data changes and does not pick
 * up a new container width; `chart.reflow()` is the Highcharts API for fitting
 * a chart to its container when the size changed without a window resize.
 */
const ScatterPlotChart = ({ options }: any) => {
	const chartRef = useRef<any>(null)
	const isSidebarCollapsed = useSelector((state: StoreState) => state.sidebar.collapsed)

	useEffect(() => {
		// NOTE(review): if the sidebar width is animated via CSS, this runs before
		// the transition ends and may measure an intermediate width — confirm.
		chartRef.current?.chart?.reflow()
	}, [isSidebarCollapsed])

	return <HighchartsReact
		highcharts={Highcharts}
		options={options}
		ref={chartRef}
	/>
}