import * as d3 from 'd3';
import ShapChart from './ShapChart';
import collapseOther from '../math/collapseOther';

/** Chart width used when the caller does not supply one. */
const DEFAULT_WIDTH = 800;

/** Chart height used when the caller does not supply one. */
const DEFAULT_HEIGHT = 450;

/**
 * Margins used when the caller does not supply any. The x range of the plot
 * begins at `left`, so the large left value presumably reserves room for the
 * feature-name tick labels on the y axis.
 */
const DEFAULT_MARGINS = { top: 72, right: 36, bottom: 48, left: 380 };

/** Number of leading feature keys drawn individually; the rest go to `collapseOther`. */
const MAX_FEATURES = 10;

/**
 * Waterfall chart for a single explained prediction: one horizontal bar per
 * feature, each starting where the previous one ended, beginning at
 * `datum.expected` and finishing at the prediction ƒ(x).
 */
export default class ShapWaterfallChart extends ShapChart {
  /**
   * @param {*} selector - Chart container, forwarded to ShapChart.
   * @param {number} [width=DEFAULT_WIDTH] - Chart width, forwarded to ShapChart.
   * @param {number} [height=DEFAULT_HEIGHT] - Chart height, forwarded to ShapChart.
   * @param {{top: number, right: number, bottom: number, left: number}} [margins=DEFAULT_MARGINS]
   *   Plot margins, forwarded to ShapChart.
   */
  constructor(selector, width = DEFAULT_WIDTH, height = DEFAULT_HEIGHT, margins = DEFAULT_MARGINS) {
    super(selector, width, height, margins);

    this.chart.classed('chart-waterfall', true);

    // Dashed vertical marker; render() moves it to the prediction's x position.
    this.xZero = this.chart.append('line')
      .classed('chart-zero-line', true)
      .attr('opacity', 0.75)
      .attr('stroke', '#000000')
      .attr('stroke-dasharray', 2)
      .attr('x1', this.width - this.margin.right)
      .attr('x2', this.width - this.margin.right)
      .attr('y1', 20)
      .attr('y2', this.height - this.margin.bottom);

    // Two-line "Prediction / ƒ(x) = …" label that follows the marker.
    this.prediction = this.chart.append('text')
      .classed('chart-axis-title', true)
      .attr('fill', 'currentColor')
      .attr('font-size', 10)
      .attr('transform', `translate(${this.width - this.margin.right}, 0)`);

    this.predictionLabel = 'Prediction';
    this.predictionTitle = this.prediction.append('tspan')
      .text(this.predictionLabel)
      .attr('x', -18)
      .attr('y', 24)
      .attr('dominant-baseline', 'central')
      .attr('text-anchor', 'end');

    this.predictionValue = this.prediction.append('tspan')
      .text('ƒ(x) = 0')
      .attr('x', -18)
      .attr('y', 36)
      .attr('dominant-baseline', 'central')
      .attr('font-weight', 'bold')
      .attr('text-anchor', 'end');
  }

  /**
   * Stack each datum's feature importances in `labels` order, so every series
   * begins where the previous one ended.
   *
   * @param {Array<{features: Object}>} data - Rows to stack.
   * @param {string[]} labels - Feature keys, in stacking order.
   * @returns {Array} d3 stack series, one per label. A feature missing from a
   *   row contributes 0.
   */
  static buildStack(data, labels) {
    return d3.stack()
      .offset(d3.stackOffsetNone)
      .keys(labels)
      .value((d, key) => {
        const entry = d.features[key];
        return entry ? entry.importance : 0;
      })(data);
  }

  /**
   * Serialise vertices into an SVG `points` attribute value.
   *
   * @param {Array<[number, number]>} coords - Vertex coordinates.
   * @returns {string} Space-separated "x,y" pairs, e.g. "1,2 3,4".
   */
  static toPoints(coords) {
    return coords.map((point) => point.join(',')).join(' ');
  }

  /**
   * Polygon for one waterfall bar: a rectangle from `xStart` to `xFinal`
   * whose `xFinal` end is an arrow head pointing in the direction of change.
   *
   * The arrow depth is capped at the bar length. A zero-length bar gets no
   * arrow (Math.sign(0) === 0) instead of the NaN that `0 / 0` would produce.
   *
   * @param {number} xStart - x coordinate where the bar begins.
   * @param {number} xFinal - x coordinate of the arrow tip.
   * @param {number} yStart - y coordinate of the bar's top edge.
   * @param {number} yWidth - Bar thickness.
   * @param {number} [arrowSize=6] - Maximum arrow-head depth.
   * @returns {string} SVG `points` string with five vertices.
   */
  static arrowPoints(xStart, xFinal, yStart, yWidth, arrowSize = 6) {
    const xDelta = xFinal - xStart;
    const xArrow = Math.min(arrowSize, Math.abs(xDelta)) * Math.sign(xDelta);
    const yArrow = yStart + yWidth / 2;
    const yFinal = yStart + yWidth;

    return this.toPoints([
      [xStart, yStart],
      [xFinal - xArrow, yStart],
      [xFinal, yArrow],
      [xFinal - xArrow, yFinal],
      [xStart, yFinal],
    ]);
  }

  /**
   * Draw the waterfall for a single explanation.
   *
   * NOTE(review): assumes `datum.features` keys are already ordered by
   * descending importance, since only the first MAX_FEATURES are kept
   * individually. Confirm with the producer.
   *
   * @param {{expected: number, features: Object<string, {importance: number}>}} datum
   *   Explanation to draw. Bars are stacked starting from `datum.expected`.
   * @param {Object} [meta={}] - Passed through to `updateTooltips`.
   */
  render(datum, meta = {}) {
    // The stacking helpers operate on arrays, so wrap the single datum.
    const data = [datum];
    const sortIndex = 0;

    const topLabels = Object.keys(data[sortIndex].features).slice(0, MAX_FEATURES);
    const topValues = collapseOther(data, topLabels);

    // Reversed: the band scale's range runs bottom -> top, so the first
    // feature key lands in the top row and is stacked last, ending at ƒ(x).
    const labels = Object.keys(topValues[sortIndex].features).reverse();
    const values = this.constructor.buildStack(topValues, labels)
      .map((series) => ({
        key: series.key,
        start: series[0][0] + datum.expected,
        final: series[0][1] + datum.expected,
      }));

    const yScale = d3.scaleBand()
      .padding(0.1)
      .range([this.height - this.margin.bottom, this.margin.top])
      .domain(labels);
    const bandwidth = yScale.bandwidth();

    // Including `expected` keeps the domain defined when nothing was stacked.
    // Otherwise it changes nothing, because the first bar starts at `expected`.
    const xDomain = d3.extent([
      datum.expected,
      ...values.map((d) => d.start),
      ...values.map((d) => d.final),
    ]);
    const xScale = d3.scaleLinear()
      .range([this.margin.left, this.width - this.margin.right])
      .domain(xDomain);

    this.updateTooltips(topValues, meta);

    this.yAxisShape.scale(yScale)
      .tickPadding(12)
      .tickSize(0)
      .tickValues(labels);

    this.yAxis.call(this.yAxisShape);
    // Show the display label from tooltipLookup when one exists, else the raw key.
    this.yAxis.selectAll('.tick > text')
      .text((d) => (this.tooltipLookup[d] ? this.tooltipLookup[d].label : d));

    this.xAxisFormat = d3.format('.4f');
    this.xAxisShape.scale(xScale)
      .tickPadding(6)
      .tickSize(6)
      .ticks(5);

    this.xAxis.call(this.xAxisShape);

    // The last stacked series ends at the prediction; with no bars the
    // prediction is just the expected value.
    const xFinalValue = values.length > 0 ? values[values.length - 1].final : datum.expected;
    const xFinalPos = xScale(xFinalValue);
    const xMiddle = xDomain[0] + ((xDomain[1] - xDomain[0]) / 2);

    this.prediction.attr('transform', `translate(${xFinalPos}, 0)`);
    this.predictionTitle.text(this.predictionLabel);
    this.predictionValue.text(`ƒ(x) = ${this.xAxisFormat(xFinalValue)}`);

    // Place the label on the side of the marker facing the centre of the
    // domain, so it stays within the plot.
    const labelOnLeft = xFinalValue >= xMiddle;
    const anchor = labelOnLeft ? 'end' : 'start';
    const offset = labelOnLeft ? -18 : 18;
    this.predictionTitle.attr('text-anchor', anchor).attr('x', offset);
    this.predictionValue.attr('text-anchor', anchor).attr('x', offset);

    this.xZero.attr('x1', xFinalPos).attr('x2', xFinalPos);

    // Colour by direction of change: delta < 0 -> blue, delta >= 0 -> red.
    const zScale = d3.scaleThreshold()
      .domain([0])
      .range(['#333399', '#993333']);

    this.shapes.selectAll('.chart-series-rect')
      .data(values)
      .join('polygon')
      .classed('chart-series-rect', true)
      .attr('fill', (d) => zScale(d.final - d.start))
      .on('mouseenter.tooltip', this.showTooltip.bind(this))
      .on('mouseleave.tooltip', this.hideTooltip.bind(this))
      .attr('points', (d) => this.constructor.arrowPoints(
        xScale(d.start),
        xScale(d.final),
        yScale(d.key),
        bandwidth,
      ));

    // Dotted connector from each bar's end up to the row above, where the
    // next stacked bar starts. The last bar has no successor.
    this.shapes.selectAll('.chart-series-line')
      .data(values.slice(0, -1))
      .join('polyline')
      .classed('chart-series-line', true)
      .attr('opacity', 0.5)
      .attr('stroke', '#000000')
      .attr('stroke-dasharray', 1)
      .attr('stroke-width', 1)
      .attr('points', (d) => {
        const x = xScale(d.final);
        const y = yScale(d.key);
        return this.constructor.toPoints([
          [x, y + bandwidth / 2],
          [x, y - yScale.step()],
        ]);
      });
  }
}
