import React, { useState } from 'react';
import MetaTags from 'react-meta-tags';

import { Link, useParams } from 'react-router-dom';
import Logo from "../../logo.png";

import {
  Tooltip,
  XAxis,
  YAxis,
  CartesianGrid,
  Legend,
  Line,
  LineChart,
  Bar,
  BarChart,
  Cell,
  ReferenceLine,
} from 'recharts';

import { makeStyles, createStyles, Theme } from '@material-ui/core/styles';

import CircularProgress from '@material-ui/core/CircularProgress';
import Paper from '@material-ui/core/Paper';
import Grid from '@material-ui/core/Grid';
import Divider from '@material-ui/core/Divider';
import Table from '@material-ui/core/Table';
import TableBody from '@material-ui/core/TableBody';
import TableHead from '@material-ui/core/TableHead';
import TableCell from '@material-ui/core/TableCell';
import TableRow from '@material-ui/core/TableRow';
import {MuiPickersUtilsProvider, KeyboardDatePicker} from '@material-ui/pickers';

import DateFnsUtils from '@date-io/date-fns';

import { useGetOrthogonalModelQuery } from '../../services/ReportingService';
import { COLORS } from '../../const';


/**
 * Route parameters read by this page. `id` is the orthogonal-model run whose
 * report is fetched and rendered.
 */
type ParamTypes = {
  id: string;
};

/**
 * Population standard deviation (divides by N, not N - 1) of `arr`.
 *
 * @param arr - the values to measure; not modified
 * @returns the standard deviation, or `NaN` when `arr` is empty (0 / 0)
 */
const standardDeviation = (arr: ReadonlyArray<number>): number => {
  const mean = arr.reduce((sum, value) => sum + value, 0) / arr.length;
  // Sum the squared deviations in a single pass (no intermediate array copies).
  const sumOfSquares = arr.reduce((sum, value) => sum + (value - mean) ** 2, 0);
  return Math.sqrt(sumOfSquares / arr.length);
};


/**
 * JSS class names for the report page:
 * - `root`: the page container, which grows to fill its parent;
 * - `paper`: padded, centred `Paper` panels in the theme's secondary text colour;
 * - `table`: the Sharpe table, at least 250px wide;
 * - `heading`: a 15px font size converted with `theme.typography.pxToRem`.
 */
const useStyles = makeStyles((theme: Theme) => {
  const panelPadding = theme.spacing(2);
  return createStyles({
    root: { flexGrow: 1 },
    paper: {
      padding: panelPadding,
      textAlign: 'center',
      color: theme.palette.text.secondary,
    },
    table: { minWidth: 250 },
    heading: { fontSize: theme.typography.pxToRem(15) },
  });
});


/**
 * Loose row type for chart and table data. The service's own row types are not
 * visible from this file, so rows are treated as string-keyed records.
 * Cumulative-return rows carry `ts` (epoch ms after conversion), `avg` and one
 * column per numeric lag. Sharpe rows carry `lag` plus a column named after
 * the feature.
 */
type ChartRow = Record<string, any>;

/** Annualisation periods for the Sharpe ratio; assumes one row per trading day — TODO confirm. */
const PERIODS_PER_YEAR = 252;

/** A series needs at least this many points (three returns) before a Sharpe ratio is computed. */
const MIN_POINTS_FOR_SHARPE = 4;

/** True when `value` is a finite number, i.e. a Sharpe ratio that can be displayed. */
function isFiniteNumber(value: unknown): value is number {
  return typeof value === 'number' && Number.isFinite(value);
}

/** Formats a Sharpe ratio with two decimals, or "N/A" when it could not be computed. */
function formatSharpe(value: unknown): string {
  return isFiniteNumber(value) ? value.toFixed(2) : 'N/A';
}

/** Keeps rows whose `ts` lies in [from, to]; a bound of 0 leaves that side unbounded. */
function rowsInRange(rows: ChartRow[], from: number, to: number): ChartRow[] {
  return rows.filter((row) => (!from || from <= row.ts) && (!to || row.ts <= to));
}

/**
 * Computes the annualised Sharpe ratio of the `avg` column and of every lag
 * column of `rows` inside the date window [from, to].
 *
 * @param rows - cumulative-return rows whose `ts` is in epoch ms
 * @param lags - the non-Average lags, in display order
 * @param feature - key under which each Sharpe value is stored
 * @param from - window start in epoch ms (0 = unbounded)
 * @param to - window end in epoch ms (0 = unbounded)
 * @returns `[{lag: "Average", [feature]: …}, {lag: lags[0], [feature]: …}, …]`.
 *   The Average entry is always first. A value stays `null` when its column
 *   has fewer than MIN_POINTS_FOR_SHARPE non-null points in the window. When
 *   `from` is after `to`, the window [to, from] is used and every ratio is negated.
 */
function computeSharpesInRange(
  rows: ChartRow[],
  lags: ReadonlyArray<unknown>,
  feature: string | number,
  from: number,
  to: number,
): ChartRow[] {
  const sharpes: ChartRow[] = [
    { lag: 'Average', [feature]: null },
    ...lags.map((lag) => ({ lag, [feature]: null })),
  ];

  const inOrder = from <= to || !from || !to;
  const windowRows = inOrder ? rowsInRange(rows, from, to) : rowsInRange(rows, to, from);
  if (windowRows.length < MIN_POINTS_FOR_SHARPE) return sharpes;
  const direction = inOrder ? 1 : -1;

  for (const key of Object.keys(windowRows[0])) {
    if (key === 'ts') continue;
    const series: number[] = windowRows.map((row) => row[key]).filter((v) => v != null);
    if (series.length < MIN_POINTS_FOR_SHARPE) continue;

    // Per-period returns as differences of consecutive cumulative values
    // (assumes the cumulative columns are additive — TODO confirm). Pairing
    // by index, not by the previous value's truthiness, keeps the step out of
    // a cumulative value of exactly 0. The mean is taken over the same
    // returns the standard deviation uses.
    const returns = series.slice(1).map((value, i) => value - series[i]);
    const meanReturn = returns.reduce((sum, r) => sum + r, 0) / returns.length;
    const sharpe = direction * Math.sqrt(PERIODS_PER_YEAR) * meanReturn / standardDeviation(returns);

    const target = key === 'avg' ? 0 : sharpes.findIndex((s) => s.lag === parseInt(key, 10));
    if (target >= 0) sharpes[target][feature] = sharpe;
  }
  return sharpes;
}

/**
 * Y-axis domain for the cumulative-returns chart.
 *
 * The domain covers every finite, non-`ts` value in the window [from, to]
 * (0 = unbounded). The lower bound is never above 0 (the chart draws a
 * ReferenceLine at y=0) and gets 40% extra room when negative. The upper
 * bound gets 10% head-room.
 *
 * @returns the domain, or ["auto", "auto"] when nothing in the window is plottable.
 */
function cumulativeReturnDomain(rows: ChartRow[], from: number, to: number): string[] | number[] {
  let bottom = Infinity;
  let top = -Infinity;
  for (const row of rowsInRange(rows, from, to)) {
    for (const key of Object.keys(row)) {
      const value = row[key];
      if (key !== 'ts' && isFiniteNumber(value)) {
        bottom = Math.min(bottom, value);
        top = Math.max(top, value);
      }
    }
  }
  if (bottom > top) return ['auto', 'auto'];
  return [Math.min(1.4 * bottom, 0), Math.max(1.1 * top, 0.9 * top)];
}

/**
 * Report page for one orthogonal-model run, identified by the `id` route param.
 *
 * For each feature returned by `useGetOrthogonalModelQuery(id)`, the page shows:
 * - a table of Sharpe ratios by lag;
 * - a cumulative-returns line chart;
 * - a Sharpe-by-lag bar chart.
 *
 * The date pickers restrict the window used for the Sharpe ratios and for the
 * chart axes; a bound of 0 means "unbounded".
 */
export default function OrthogonalModelReport(){
  const { id } = useParams<ParamTypes>();
  const classes = useStyles();
  const { data, error, isLoading } = useGetOrthogonalModelQuery(id);

  // Window bounds in epoch ms; 0 means no bound on that side.
  const [startDate, setStartDate] = useState(0);
  const [endDate, setEndDate] = useState(0);

  let content = <p>No request sent to server</p>;
  let description = "Orthogonal Model Report";

  // Clearing a picker removes that bound. Partially typed (invalid) dates are
  // ignored, so they neither reset the filter nor put NaN into state.
  const onDateChange = (setDate: (ts: number) => void) => (date: Date | null) => {
    if (date === null) {
      setDate(0);
      return;
    }
    const ts = new Date(date).getTime();
    if (!Number.isNaN(ts)) setDate(ts);
  };

  if (isLoading){
    content = (
      <h2>
        Building orthogonal model report...
        <CircularProgress />
      </h2>
    );
  } else if (error){
    content = (
      <div>
        <h3>Error building orthogonal model report</h3>
        <p>{JSON.stringify(error)}</p>
      </div>
    );
  } else if (data) {
    const rows: React.ReactElement[] = [];
    const descriptionElements: string[] = [];

    const optionBar = (
      <Paper elevation={2} style={{height: 80, padding: 10, margin: 0}}>
        <MuiPickersUtilsProvider utils={DateFnsUtils}>
          <Grid container justifyContent="space-around">
            <KeyboardDatePicker
              disableToolbar
              variant="inline"
              format="MM/dd/yyyy"
              id="start-date-picker-inline"
              label="Start Date"
              value={startDate || null}
              onChange={onDateChange(setStartDate)}
              KeyboardButtonProps={{'aria-label': 'change start date'}}
            />
            <KeyboardDatePicker
              disableToolbar
              variant="inline"
              format="MM/dd/yyyy"
              id="end-date-picker-inline"
              label="End Date"
              value={endDate || null}
              onChange={onDateChange(setEndDate)}
              KeyboardButtonProps={{'aria-label': 'change end date'}}
            />
          </Grid>
        </MuiPickersUtilsProvider>
      </Paper>
    );

    rows.push(
      <Grid container spacing={3} key="optionBarGrid" style={{ "marginBottom": "5px" }}>
        <Grid item xs={12}>
          {optionBar}
        </Grid>
      </Grid>
    );

    for (const { sharpe_ratio_by_lag, cum_returns_by_lag, feature } of Object.values(data)) {
      // Copy each row with its timestamp converted to epoch ms for the numeric x-axis.
      const cumReturns: ChartRow[] = cum_returns_by_lag.map((entry: any) => ({
        ...entry,
        ts: new Date(entry.ts).getTime(),
      }));

      // Non-Average lags in display order; a lag's index here selects its colour in every chart.
      const lags = sharpe_ratio_by_lag.filter(entry => (entry.lag !== "Average")).map(entry => entry.lag);

      const sharpesInRange = computeSharpesInRange(cumReturns, lags, feature, startDate, endDate);
      // computeSharpesInRange always puts the Average entry first.
      const avgSharpe: unknown = sharpesInRange[0][feature];
      descriptionElements.push(
        `${feature}: Avg Sharpe ${isFiniteNumber(avgSharpe) ? Math.round(avgSharpe * 10) / 10 : 'N/A'}`
      );

      const sharpeBarChart = (
        <BarChart width={900} height={200} data={sharpesInRange} margin={{top: 5, right: 30, left: 20, bottom: 3}}>
          <XAxis dataKey="lag" label={{value: "Lag"}}/>
          <CartesianGrid strokeDasharray="3 3" />
          <YAxis type="number" label={{ value: "Sharpe", angle: -90, position: 'insideLeft' }} />
          <Tooltip />
          <Bar dataKey={feature}>
            {sharpesInRange.map((entry, index) => (
              lags.includes(entry.lag)
                ? <Cell key={`bar-cell-${index}`} fill={COLORS[lags.indexOf(entry.lag)]}/>
                : <Cell key="bar-cell-Average" fill="#363636"/>
            ))}
          </Bar>
        </BarChart>
      );

      // Every element of this list needs a stable React key.
      const lagColumns = cumReturns.length > 0 ? Object.keys(cumReturns[0]) : [];
      const cumReturnLines = [
        <Line key="avg" type="monotone" dataKey="avg" dot={false} strokeWidth="2.5" stroke="#363636"/>,
        ...lagColumns
          .filter((key) => lags.includes(parseInt(key, 10)))
          .map((key) => (
            <Line
              key={key}
              type="monotone"
              dataKey={key}
              dot={false}
              strokeOpacity="0.3"
              stroke={COLORS[lags.indexOf(parseInt(key, 10))]}
            />
          )),
      ];

      const cumReturnsChart = (
        <LineChart width={900} height={400} data={cumReturns} margin={{top: 5, right: 30, left: 20, bottom: 5}}>
          <XAxis allowDataOverflow dataKey="ts" domain={[(startDate || "dataMin"), (endDate || "dataMax")]} type="number" tickFormatter={v => (new Date(v)).toLocaleDateString("en-US")}/>
          <CartesianGrid strokeDasharray="3 3" />
          <YAxis allowDataOverflow domain={cumulativeReturnDomain(cumReturns, startDate, endDate)} type="number" label={{ value: "Cumulative Returns", angle: -90, position: 'insideLeft' }} tickFormatter={v => v.toFixed(2)}/>
          <Tooltip labelFormatter={v => (new Date(v)).toLocaleDateString("en-US")}/>
          <Legend layout="vertical" align="right" verticalAlign="top" wrapperStyle={{"paddingLeft": "10px"}}/>
          <ReferenceLine y={0} />
          {cumReturnLines}
        </LineChart>
      );

      // The Average row always goes at the top of the table, before the normal lags.
      const sharpeRows = [
        <TableRow key="sharpe_Average" style={{ "height": "35.2px" }}>
          <TableCell>Average</TableCell>
          <TableCell>{formatSharpe(avgSharpe)}</TableCell>
        </TableRow>,
        ...sharpesInRange
          .filter((entry) => entry.lag !== "Average")
          .map((entry) => (
            <TableRow key={`sharpe_${entry.lag}`} style={{ "height": '35.2px' }}>
              <TableCell>{entry.lag}</TableCell>
              <TableCell>{formatSharpe(entry[feature])}</TableCell>
            </TableRow>
          )),
      ];

      const sharpeTable = (
        <Paper className={classes.paper}>
          <Table className={classes.table} size="small" aria-label="simple table">
            <TableHead>
              <TableRow key="sharpe_header" style={{ "height": '35.2px' }}>
                <TableCell>Lag</TableCell>
                <TableCell>Sharpe</TableCell>
              </TableRow>
            </TableHead>
            <TableBody>
              {sharpeRows}
            </TableBody>
          </Table>
        </Paper>
      );

      // `rows` already holds the option bar, so every feature section
      // (including the first) is preceded by a divider.
      rows.push(<Divider key={`${feature}_divider`}/>);

      rows.push(
        <Grid container spacing={3} key={feature} style={{ "marginBottom": "50px" }}>
          <Grid item xs={12} style={{ "marginBottom": "-20px" }}>
            <h3>Feature: {feature}</h3>
          </Grid>
          <Grid item xs={3}>
            {sharpeTable}
          </Grid>
          <Grid item xs={9}>
            <Paper className={classes.paper}>
              {cumReturnsChart}
              {sharpeBarChart}
            </Paper>
          </Grid>
        </Grid>
      );
    }

    // Keep the generic description when the run returned no features.
    if (descriptionElements.length > 0) {
      description = descriptionElements.join(", ");
    }

    content = (
      <>
        <h2>
          Orthogonal model run <Link to={`/orthogonal/config/${id}`}>{id}</Link>
        </h2>
        {rows}
      </>
    );
  }

  return (
    <div className={classes.root}>
      <MetaTags>
        <title>Orthogonal Model Report</title>
        <meta property="og:description" content={description} />
        <meta property="og:image" content={Logo} />
      </MetaTags>
      {content}
    </div>
  );
}
