import React, { FC, useEffect, useRef, useState } from 'react';

import { css } from '@emotion/react';
import { max, merge, sum, transpose } from 'd3-array';
import { axisBottom, axisLeft } from 'd3-axis';
import { scaleBand, scaleLinear, scaleOrdinal } from 'd3-scale';
import { select } from 'd3-selection';
import { stack } from 'd3-shape';
import cloneDeep from 'lodash-es/cloneDeep';

import { ChartAxisType } from '~/types';

import ChartDescription from '../ChartDescription';
import {
	getBasis,
	getFormat,
	getTicksAmount,
	sanitizeValue,
	scaleValue,
	valueOf
} from '../chartHelpers';
import { patternClasses } from '../PatternsSvg';

import type { Props as BarChartProps } from './BarChart';
import type { ChartElement } from '~/types/WebtextManifest';

/**
 * Bar chart drawn imperatively with d3 into an `<svg>`, followed by a
 * `ChartDescription` for the same chart.
 *
 * `chart.data` is read as a table: the first row holds the series labels (its first
 * cell is skipped) and the first column holds the category labels. Every other cell
 * is a value parsed with `valueOf` / `sanitizeValue`.
 * `chart.orientation` ('horizontal' | 'vertical') picks which axis carries the
 * categories. `chart.series_orientation` ('stacked' | 'grouped') picks how the series
 * of one category are laid out.
 *
 * When `monochrome` is set, bars are filled with `url(#…)` references taken from
 * `patternClasses` instead of `chart.colors`.
 */
const BarChart: FC<BarChartProps> = (props) => {
	const { chart, monochrome } = props;

	const ref = useRef<SVGSVGElement>(null);

	// Clear and redraw from scratch whenever the chart definition changes. `monochrome`
	// is read while drawing (fills/strokes), so a change to it must redraw as well.
	useEffect(() => {
		const svg = select(ref.current);
		svg.selectAll('*').remove();

		draw();
	}, [chart, monochrome]);

	/**
	 * Converts the raw table into the row objects consumed by d3 `stack()`.
	 * Produces one object per data row, keyed by the header cells. Every cell is passed
	 * through `sanitizeValue`, and falsy results become 0. The header row is attached as
	 * `columns`, the same shape d3-dsv produces.
	 */
	const getStackedData = (data) => {
		const columns = data[0];
		const rows = data
			.slice(1)
			.map((row) => Object.fromEntries(columns.map((key, i) => [key, sanitizeValue(row[i]) || 0])));

		return Object.assign(rows, { columns });
	};

	/**
	 * Draws a stacked chart: one `<g>` per d3 stack layer (series) and one `<rect>` per
	 * category in that layer. When `chart.show_labels` is set, it also adds one white
	 * value label per segment.
	 * Position/size arguments are d3 attribute values (constants or `(datum, index)`
	 * accessors) supplied by `draw`.
	 *
	 * @returns the selection of layer groups.
	 */
	const drawStackedChart = ({
		svg,
		layers,
		color,
		pattern,
		x,
		y,
		width,
		height,
		labelX,
		labelY,
		labelDX,
		labelDY,
		format,
		valueType
	}) => {
		const group = svg
			.selectAll('.groups')
			.data(layers)
			.enter()
			.append('g')
			.attr('fill', (d, i) => (monochrome ? `url(#${pattern(i)})` : color(i)));

		group
			.selectAll('rect')
			.data((d: any) => d)
			.enter()
			.append('rect')
			.attr('x', x)
			.attr('y', y)
			.attr('width', width)
			.attr('height', height);

		if (chart.show_labels)
			group
				.selectAll('text')
				.data((d: any) => d)
				.enter()
				.append('text')
				.attr('class', 'label light')
				.attr('x', labelX)
				.attr('y', labelY)
				.attr('dy', labelDY)
				.attr('dx', labelDX)
				.attr('text-anchor', 'middle')
				/**
				 * Each stack datum looks like [lower, upper, data: {…}], so a segment's own value
				 * is `upper - lower`. That difference is an unformatted number (e.g. 81.6).
				 * The percent formatter expects a fraction, so without dividing by 100 the
				 * label would read `8160%` instead of `82%`.
				 */
				.text((d: any) => {
					const value = d[1] - d[0];
					return format(valueType === 'percent' ? value / 100 : value);
				})
				.attr('fill', '#ffffff');

		return group;
	};

	/**
	 * Draws a grouped chart: one `<g>` per category (placed by `transform`) and one
	 * `<rect>` per series inside it. When `chart.show_labels` is set, it also adds one
	 * value label per bar, formatted from the parsed value's `.value`.
	 */
	const drawGroupedChart = ({
		svg,
		categories,
		color,
		pattern,
		transform,
		x,
		y,
		width,
		height,
		labelX,
		labelY,
		labelDX,
		labelDY,
		format
	}) => {
		const group = svg
			.selectAll('.groups')
			.data(categories)
			.enter()
			.append('g')
			.attr('class', 'g')
			.attr('transform', transform);

		group
			.selectAll('rect')
			.data((d: any) => d)
			.enter()
			.append('rect')
			.attr('x', x)
			.attr('y', y)
			.attr('width', width)
			.attr('height', height)
			.attr('fill', (d, i) => (monochrome ? `url(#${pattern(i)})` : color(i)))
			.attr('stroke', (d, i) => (monochrome ? '#000000' : color(i)));

		if (chart.show_labels)
			group
				.selectAll('text')
				.data((d: any) => d)
				.enter()
				.append('text')
				.attr('class', 'label')
				.attr('x', labelX)
				.attr('y', labelY)
				.attr('dx', labelDX)
				.attr('dy', labelDY)
				.attr('text-anchor', 'middle')
				.attr('alignment-baseline', 'middle')
				.text((d: any) => format(d.value));
	};

	/**
	 * Builds the scales and axes from `chart`, then renders the bars, value labels, axes
	 * and axis titles into the SVG referenced by `ref`. Expects the SVG to have been
	 * cleared by the caller.
	 */
	const draw = () => {
		// Work on a private copy. The table is consumed with `shift()` below, and the effect
		// can run more than once for the same `chart` (e.g. React StrictMode in development).
		// Sharing one render-time copy would drop a row on every extra run.
		const chartData = cloneDeep(chart.data);
		if (!chartData?.length) return;

		const seriesLabels = chartData.shift().slice(1);

		let categories: any = transpose(chartData);
		const categoryLabels: Array<any> = categories.shift();
		const rawCategories: any[][] = transpose(categories);

		// Computed per draw, not kept in React state. A state update made here would only
		// be visible to the next render, never to the draw that is currently running.
		const seriesContainsNegativeValue = rawCategories.some((row) =>
			row.some((value: any) => /^-/.test(value))
		);
		categories = rawCategories.map((row) => row.map((value: any) => valueOf(value)));

		// Nothing to plot: header row only, or no value columns.
		if (!categories[0]?.length) return;

		const margin = {
			top: (+chart.margins[0] || 0) + 33,
			right: (+chart.margins[1] || 0) + 33,
			bottom: (+chart.margins[2] || 0) + 33,
			left: (+chart.margins[3] || 0) + 33
		};

		const staticAxisLabelMargin = 40;

		margin.bottom += chart.x_axis_label?.length > 0 ? staticAxisLabelMargin : 0;
		margin.left += chart.y_axis_label?.length > 0 ? staticAxisLabelMargin : 0;

		const viewBoxWidth = 700;
		const viewBoxHeight = 400;

		const width = viewBoxWidth - margin.left - margin.right;
		const height = viewBoxHeight - margin.top - margin.bottom;
		const color: any = scaleOrdinal().range(chart.colors);
		const pattern: any = scaleOrdinal().range(patternClasses);

		const svg = select(ref.current)
			.attr('class', 'chart')
			.attr('viewBox', `0 0 ${viewBoxWidth} ${viewBoxHeight}`)
			.attr('preserveAspectRatio', 'xMidYMid meet')
			.append('g')
			.attr('transform', `translate(${margin.left}, ${margin.top})`);

		const groupsScale = scaleBand().domain(categoryLabels);
		const seriesScale = scaleBand().domain(seriesLabels);
		const valuesScale = scaleLinear();

		switch (chart.orientation) {
			case 'horizontal':
				groupsScale.range([0, height]).padding(0.1);
				valuesScale.range([0, width]);
				break;
			case 'vertical':
				groupsScale.range([0, width]).padding(0.1);
				valuesScale.range([height, 0]);
				break;
		}

		const valuesLowerBound = sanitizeValue(chart.min_bound) || 0;
		let valuesUpperBound = sanitizeValue(chart.max_bound) || 0;

		let layers;

		switch (chart.series_orientation) {
			case 'stacked': {
				// Without an explicit max bound, the tallest stack sets the upper bound.
				const maxGroupSum: any = max(categories.map((a) => sum(a.map((d) => d.value))));
				valuesUpperBound ||= maxGroupSum;

				const data = getStackedData(chart.data);
				layers = stack().keys(seriesLabels)(data);
				break;
			}

			case 'grouped': {
				const maxValue = max(merge(categories), (d: any) => d.value);
				valuesUpperBound ||= maxValue;

				seriesScale.rangeRound([0, groupsScale.bandwidth()]);
				break;
			}
		}

		const valueType = categories[0][0].type;
		const basis = getBasis(valuesLowerBound, valuesUpperBound, valueType);
		valuesScale.domain([scaleValue(valuesLowerBound, basis), scaleValue(valuesUpperBound, basis)]);

		const formatValue = getFormat(valueType);
		const ticksAmount = getTicksAmount(valuesLowerBound, valuesUpperBound, valueType === 'percent');

		let xAxis, yAxis;
		switch (chart.orientation) {
			case 'horizontal':
				xAxis = axisBottom(valuesScale).ticks(ticksAmount).tickFormat(formatValue);
				yAxis = axisLeft(groupsScale).tickSize(0).tickPadding(6);
				break;
			case 'vertical':
				xAxis = axisBottom(groupsScale).tickSize(0).tickPadding(6);
				yAxis = axisLeft(valuesScale).ticks(ticksAmount).tickFormat(formatValue);
				break;
		}

		switch (`${chart.orientation}-${chart.series_orientation}`) {
			case 'horizontal-stacked':
				drawStackedChart({
					svg,
					layers,
					color,
					pattern,
					x: (d: any) => valuesScale(d[0] / basis),
					y: (d, i) => groupsScale(categoryLabels[i]),
					// Segment length is the distance between its scaled ends. Scaling the raw
					// difference is only correct when the value axis starts at 0.
					width: (d: any) => valuesScale(d[1] / basis) - valuesScale(d[0] / basis),
					height: groupsScale.bandwidth(),
					labelX: (d: any) =>
						valuesScale(d[0] / basis) + (valuesScale(d[1] / basis) - valuesScale(d[0] / basis)) / 2,
					labelY: (d, i) => groupsScale(categoryLabels[i]),
					labelDX: '0.5em',
					labelDY: groupsScale.bandwidth() / 2 + 5,
					format: formatValue,
					valueType
				});
				break;

			case 'horizontal-grouped':
				drawGroupedChart({
					svg,
					categories,
					color,
					pattern,
					transform: (d, i) => `translate(0, ${groupsScale(categoryLabels[i])})`,
					x: (d: any) => (d.value < 0 ? -Math.abs(valuesScale(d.value) - valuesScale(0)) : 0),
					y: (d, i) => seriesScale(seriesLabels[i]),
					width: (d: any) =>
						d.value < 0 ? Math.abs(valuesScale(d.value) - valuesScale(0)) : valuesScale(d.value),
					height: seriesScale.bandwidth(),
					labelX: (d: any) => valuesScale(d.value) + 3,
					labelY: (d, i) => seriesScale(seriesLabels[i]),
					labelDX: '1.1em',
					labelDY: seriesScale.bandwidth() / 2,
					format: formatValue
				});
				break;

			case 'vertical-stacked':
				// NOTE(review): unlike the horizontal stacked branch, values here are not divided
				// by `basis` before scaling. Confirm against `scaleValue` which one is intended.
				drawStackedChart({
					svg,
					layers,
					color,
					pattern,
					x: (d, i) => groupsScale(categoryLabels[i]),
					y: (d: any) => valuesScale(d[1]),
					width: groupsScale.bandwidth(),
					height: (d: any) => valuesScale(d[0]) - valuesScale(d[1]),
					labelX: (d, i) => groupsScale(categoryLabels[i]),
					labelY: (d: any) => valuesScale(d[0]) + (valuesScale(d[1]) - valuesScale(d[0])) / 2,
					labelDX: groupsScale.bandwidth() / 2,
					labelDY: '0',
					format: formatValue,
					// Needed so percent segment labels are scaled like in the horizontal branch.
					valueType
				});
				break;

			case 'vertical-grouped':
				drawGroupedChart({
					svg,
					categories,
					color,
					pattern,
					transform: (d, i) => `translate(${groupsScale(categoryLabels[i])}, 0)`,
					x: (d, i) => seriesScale(seriesLabels[i]),
					// Positive bars grow upwards from their value; negative bars hang from 0.
					y: (d: any) => (d.value > 0 ? valuesScale(d.value) : valuesScale(0)),
					width: seriesScale.bandwidth(),
					height: (d: any) =>
						seriesContainsNegativeValue
							? Math.abs(valuesScale(d.value) - valuesScale(0))
							: height - valuesScale(d.value),
					labelX: (d, i) => seriesScale(seriesLabels[i]),
					labelY: (d: any) => valuesScale(d.value) - 4,
					labelDX: seriesScale.bandwidth() / 2,
					labelDY: '-0.3em',
					format: formatValue
				});

				if (seriesContainsNegativeValue) {
					// Baseline at value 0. It gets its own axis generator so the shared `xAxis`,
					// which still renders the category labels at the bottom, is not mutated.
					svg
						.append('g')
						.attr('class', 'x axis zero')
						.attr('transform', `translate(0,${valuesScale(0)})`)
						.call(axisBottom(groupsScale).tickSize(0).tickFormat(() => ''));
				}
				break;
		}

		const x_labels = svg
			.append('g')
			.attr('class', 'x axis')
			.attr('transform', `translate(0,${height})`)
			.call(xAxis);

		const y_labels = svg.append('g').attr('class', 'y axis').call(yAxis);

		if (chart.orientation === 'vertical' && chart.axis_type === ChartAxisType.tilted) {
			x_labels
				.selectAll('text')
				.attr('text-anchor', 'end')
				.attr('dx', '-0.4em')
				.attr('dy', '.4em')
				.attr('transform', 'rotate(-45)');
		} else if (chart.orientation === 'horizontal' && chart.axis_type === ChartAxisType.tilted) {
			y_labels
				.selectAll('text')
				.attr('text-anchor', 'end')
				.attr('dx', '-0.5em')
				.attr('dy', '-.5em')
				.attr('transform', 'rotate(-45)');
		}

		if (chart.x_axis_label?.length > 0) {
			svg
				.append('text')
				.attr('class', 'x axis-label')
				.attr('text-anchor', 'middle')
				.attr('x', width / 2)
				.attr('y', height + 40)
				.text(chart.x_axis_label);
		}

		if (chart.y_axis_label?.length > 0) {
			// Widest rendered y tick label. It stays 0 when the axis has no labels, which
			// keeps the left margin below a number instead of NaN.
			let yLabelsMaxWidth = 0;
			y_labels.selectAll('text').each(function () {
				yLabelsMaxWidth = Math.max(yLabelsMaxWidth, (this as SVGGraphicsElement).getBBox().width);
			});

			const yAxisLabelMarginShift = 15;
			// Shift the plot right so the rotated y title clears the widest tick label.
			margin.left = yLabelsMaxWidth + staticAxisLabelMargin;

			svg
				.attr('transform', `translate(${margin.left}, ${margin.top})`)
				.append('g')
				.attr('transform', `translate(${yAxisLabelMarginShift - margin.left}, ${height / 2} )`)
				.append('text')
				.attr('class', 'y axis-label')
				.attr('text-anchor', 'middle')
				.attr('transform', 'rotate(-90)')
				.text(chart.y_axis_label);
		}
	};

	return (
		<div css={styles} data-bar-chart>
			<svg ref={ref} aria-hidden />
			<ChartDescription chart={chart as ChartElement} />
		</div>
	);
};

/**
 * Emotion styles for the bar chart wrapper, resolved against the theme.
 *
 * - `.label` styles the value labels drawn on the bars.
 * - `.x.axis` / `.y.axis` style the d3 axes.
 * - `.x.axis-label` / `.y.axis-label` style the axis titles.
 *
 * Reads `theme.fonts.app` and `theme.fonts['haas-grotesk-all']`.
 */
const styles = (theme) => css`
	.label {
		font-family: ${theme.fonts.app};
		font-size: 13px;
	}

	.x,
	.y {
		&.axis {
			font-size: 11px;
		}

		&.axis-label {
			font-size: 15px;
			font-family: ${theme.fonts['haas-grotesk-all']};
		}
	}
`;

export default BarChart;
