import _ from 'lodash'
import { ASTKinds, type Modifier } from './dice'
import { type DiceContext, type DicePart, type SingleDiceResult, evaluate } from './diceEval'

/**
 * Rolls one fair die with `x` faces.
 *
 * @param x number of faces
 * @returns a uniformly random integer between 1 and `x` (inclusive)
 */
export const dX = (x: number) => {
  const faceIndex = Math.floor(Math.random() * x)
  return faceIndex + 1
}

/** Shorthand rollers for the standard polyhedral dice; each returns 1..N. */
export const d4 = (): number => dX(4)
export const d6 = (): number => dX(6)
export const d8 = (): number => dX(8)
export const d10 = (): number => dX(10)
export const d12 = (): number => dX(12)
export const d20 = (): number => dX(20)
export const d100 = (): number => dX(100)

/** A single die within a roll. */
export type Die = {
  // number of faces on the die
  sides: number
  // face shown; 0 until roll() assigns a random face (mkDie creates dice with 0)
  value: number
  // markers set by roll() and the modifiers, e.g. 'ignored' (excluded from the
  // rolled total), 'maximum', 'maximized', 'crit', 'explode', 'reroll', 'rerolled'
  tags: string[]
}

// Mutable state of one roll; roll() passes it to every modifier hook.
type Dice = {
  // the whole pool, including dice appended by modifiers and dice tagged 'ignored'
  dice: Die[]
  // the parsed modifiers passed to roll() (ctx's global/extra modifiers are not included)
  mods: Modifier[]
  // evaluation context the roll was made with
  ctx: DiceContext
}
/**
 * Creates a not-yet-rolled die (value 0).
 *
 * @param sides number of faces
 * @param tags  initial tags; a fresh empty array is used when omitted
 */
function mkDie(sides: number, tags: string[] = []) {
  const die = { sides, value: 0, tags }
  return die
}

// Hooks implemented by every dice modifier; roll() invokes them for each roll.
type DiceModifier = {
  // called once before any die is rolled; may add dice to (or replace) dice.dice
  alterDice(dice: Dice): void
  // called right after each die is rolled; may change or tag that die, or append more dice
  alterRoll(dice: Dice, die: Die): void
  // called once after every die is rolled, e.g. to tag dice 'ignored' or reorder them
  alterResult(dice: Dice): void
  // folded over all modifiers before rolling: gets the requested dice count, the
  // number of sides and the running average (num); returns the new average
  alterAverage(amount: number, sides: number, num: number): number
  // text appended to the roll's expression string
  toString(): string
}

/**
 * Builds a "keep N" modifier: after rolling, only the `amount` best dice count
 * toward the total and the rest are tagged 'ignored'.
 *
 * If the pool has `amount` dice or fewer, extra dice (same size as the first
 * die) are added before rolling so that there is always at least one die to
 * drop. This is how `1d20` with advantage becomes two d20s.
 *
 * @param amount  number of dice to keep (default 1)
 * @param sortDir 1 keeps the highest dice, -1 keeps the lowest
 */
const keepHighest = (amount: number = 1, sortDir: number = 1) => ({
  alterDice(expr: Dice) {
    // An empty pool (e.g. `0d20a`) has no die to take the side count from.
    if (expr.dice.length === 0) return
    // generate more dice so we always have one more than we are keeping
    if (expr.dice.length <= amount) {
      const sides = expr.dice[0].sides
      const padding = _.range(amount - expr.dice.length + 1).map(() => mkDie(sides))
      expr.dice = expr.dice.concat(padding)
    }
  },
  alterRoll() {},
  alterResult(expr: Dice) {
    // Dice that are already ignored (e.g. replaced by a reroll) are not candidates,
    // so they are neither kept nor ignored a second time.
    // filter() returns a new array, so sorting it leaves expr.dice's order intact.
    const candidates = expr.dice
      .filter((die) => !die.tags.includes('ignored'))
      .sort((a: Die, b: Die) => sortDir * (b.value - a.value))
    candidates.slice(amount).forEach((die) => die.tags.push('ignored'))
  },
  alterAverage(_amount: number, _sides: number, num: number): number {
    return num
  },
  toString() {
    const direction = sortDir > 0 ? 'h' : 'l'
    return `${direction}${amount}`
  },
})

/**
 * Registry of dice modifiers, keyed by the modifier letter.
 *
 * Each entry is a factory. It receives the modifier's already-evaluated numeric
 * argument (if any) and returns a new DiceModifier. A fresh instance is built
 * for every roll, so factories may keep per-roll state in closures.
 */
const modifiers: { [key: string]: (arg?: number) => DiceModifier } = {
  // crit, double all dice amounts
  c: () => ({
    alterDice(expr: Dice) {
      // nothing to double, and no die to copy the side count from
      if (expr.dice.length === 0) return
      const sides = expr.dice[0].sides
      const critDice = _.range(expr.dice.length).map(() => mkDie(sides, ['crit']))
      expr.dice.push(...critDice)
    },
    alterRoll() {},
    alterResult() {},
    alterAverage(_amount: number, _sides: number, num: number): number {
      return num
    },
    toString() {
      return `c`
    },
  }),
  // keep highest
  h: keepHighest,
  // keep lowest
  l: (amount) => keepHighest(amount, -1),
  // advantage
  a: () => keepHighest(1),
  // disadvantage
  d: () => keepHighest(1, -1),
  // explode, ie fluxfiber conduit: a die showing at least `explodeAbove`
  // (default: its highest face) adds another die of the same size
  x: (explodeAbove?: number) => ({
    alterDice() {},
    alterRoll(expr: Dice, die: Die) {
      if (die.value >= (explodeAbove || die.sides)) {
        die.tags.push('explode')
        expr.dice.push(mkDie(die.sides, ['explode']))
      }
    },
    alterResult() {},
    alterAverage(_amount: number, _sides: number, num: number): number {
      return num
    },
    toString() {
      // keep the threshold in the expression text, as `r` and `m` do
      return `x${explodeAbove ?? ''}`
    },
  }),
  // reroll below, ie great weapon fighting: a die showing `rerollBelow` or less
  // is ignored and replaced by one new die, which is itself never rerolled
  r: (rerollBelow: number = 1) => ({
    alterDice() {},
    alterRoll(expr: Dice, die: Die) {
      if (die.value <= rerollBelow && !die.tags.includes('reroll')) {
        die.tags.push('ignored', 'rerolled')
        expr.dice.push(mkDie(die.sides, ['reroll']))
      }
    },
    alterResult() {},
    alterAverage(_amount: number, _sides: number, num: number): number {
      return num
    },
    toString() {
      return `r${rerollBelow}`
    },
  }),
  // sort results
  s: () => ({
    alterDice() {},
    alterRoll() {},
    alterResult(expr: Dice) {
      expr.dice.sort((a, b) => a.value - b.value)
    },
    alterAverage(_amount: number, _sides: number, num: number): number {
      return num
    },
    toString() {
      return `s`
    },
  }),
  // maximise the first `maxNum` dice rolled; a negative count (default -1)
  // maximises every die
  m: (maxNum: number = -1) => {
    // Countdown for alterRoll. It is kept separate from `maxNum` so that
    // alterAverage/toString always see the configured value.
    let remaining = maxNum
    return {
      alterDice() {},
      alterRoll(_expr: Dice, die: Die) {
        if (remaining === 0) return
        die.value = die.sides
        die.tags.push('maximized')
        remaining--
      },
      alterResult() {},
      alterAverage(amount: number, sides: number, _num: number): number {
        // never count more maximised dice than are actually rolled
        const maxed = maxNum < 0 ? amount : Math.min(maxNum, amount)
        return maxed * sides + Math.ceil(((amount - maxed) * (sides + 1)) / 2)
      },
      toString() {
        return `m${maxNum === -1 ? '' : maxNum}`
      },
    }
  },
  // player hp: the first `maxCount` dice (default 1) are maximized; for the
  // average, the remaining dice count as their rounded-up average
  p: (maxCount: number = 1) => {
    let remaining = maxCount
    return {
      alterDice() {},
      alterRoll(_expr: Dice, die: Die) {
        if (remaining <= 0) return
        die.value = die.sides
        die.tags.push('maximized')
        remaining--
      },
      alterResult() {},
      alterAverage(amount: number, sides: number, _num: number): number {
        // use the configured count, clamped to the number of dice rolled
        const maxed = Math.max(0, Math.min(maxCount, amount))
        return maxed * sides + (amount - maxed) * Math.ceil((sides + 1) / 2)
      },
      toString() {
        return ''
      },
    }
  },
  // keen: crit limit is reduced by `reduction` (default 1); dice at or above the
  // lowered limit are tagged 'maximum'
  k: (reduction: number = 1) => ({
    alterDice() {},
    alterRoll(_expr: Dice, die: Die) {
      if (reduction <= 0) return
      // a natural maximum is tagged 'maximum' by roll() after the modifiers run
      if (die.value < die.sides && die.value >= die.sides - reduction) die.tags.push('maximum')
    },
    alterResult() {},
    // keen only widens the crit range; the expected total is unchanged
    alterAverage(_amount: number, _sides: number, num: number): number {
      return num
    },
    toString() {
      return ''
    },
  }),
}

/**
 * Resolves every modifier that applies to one roll into DiceModifier instances.
 *
 * Modifiers are taken in this order: the roll's own parsed `mods`, then
 * `ctx.globalMods`, then `ctx.extraMods` (which are first converted into
 * Modifier AST nodes with a Constant argument). Types without an entry in the
 * `modifiers` registry are silently dropped. An AST argument is evaluated with
 * `evaluate(arg, ctx).rolled` and the result is passed to the factory.
 *
 * @param mods modifiers parsed from the dice expression
 * @param ctx  evaluation context supplying global/extra modifiers
 */
function computeMods(mods: Modifier[], ctx: DiceContext): DiceModifier[] {
  const extraMods: Modifier[] = ctx.extraMods.map(({ type, arg }) => ({
    kind: ASTKinds.Modifier,
    type,
    // `== null` instead of a truthiness test, so an explicit 0 argument
    // (e.g. r0 / m0) is honoured rather than falling back to the default.
    arg:
      arg == null
        ? null
        : {
            kind: ASTKinds.Constant,
            value: `${arg}`,
          },
  }))

  return (
    [...mods, ...ctx.globalMods, ...extraMods]
      // Own-property check: indexing a plain object literal would also "find"
      // inherited members such as `toString` or `constructor`.
      .filter((mod) => Object.prototype.hasOwnProperty.call(modifiers, mod.type))
      .map((mod) => {
        const factory = modifiers[mod.type]
        if ('arg' in mod && mod.arg) {
          return factory(evaluate(mod.arg, ctx).rolled)
        }
        return factory()
      })
  )
}

/**
 * Rolls `amount` dice with `sides` faces. It applies the given modifiers plus
 * any global/extra modifiers from `ctx`.
 *
 * Modifier hooks run in this order:
 * - `alterAverage`, to compute the reported average;
 * - `alterDice`, once, before rolling;
 * - `alterRoll`, after each die is rolled (it may append further dice);
 * - `alterResult`, once, after all dice are rolled.
 *
 * Dice tagged 'ignored' do not count toward `rolled`.
 *
 * @param amount    number of dice to roll
 * @param sides     number of faces on each die
 * @param modifiers modifiers parsed for this roll
 * @param ctx       evaluation context (source of global/extra modifiers)
 * @returns the rolled total, the modified average, per-die results and the
 *          expression string
 */
export function roll(amount: number, sides: number, modifiers: Modifier[], ctx: DiceContext): SingleDiceResult {
  // Upper bound on dice that modifiers may append while rolling (explode, reroll).
  // It lets an always-exploding roll such as `1d1x` terminate.
  const maxExtraRolls = 100

  const mods = computeMods(modifiers, ctx)

  // TODO: implement more of these
  const average = Math.floor((amount * (sides + 1)) / 2)
  const moddedAvg = mods.reduce((num, mod) => mod.alterAverage(amount, sides, num), average)

  const modsText = mods.map((m) => m.toString()).join('')

  const dice: Dice = {
    dice: _.range(0, amount).map(() => mkDie(sides)),
    mods: modifiers,
    ctx,
  }

  mods.forEach((mod) => mod.alterDice(dice))

  // The cap is relative to the pool after alterDice. Large pools (e.g. 200d6, or
  // a crit-doubled 60d6) are therefore rolled in full; only dice appended during
  // rolling are limited.
  const maxRolls = dice.dice.length + maxExtraRolls
  let index = 0
  // arcane loop because alterRoll may append to the collection while iterating
  while (index < dice.dice.length && index < maxRolls) {
    const die = dice.dice[index]
    die.value = dX(die.sides)
    mods.forEach((mod) => mod.alterRoll(dice, die))
    // a modifier (e.g. keen) may already have tagged this die
    if (die.value === die.sides && !die.tags.includes('maximum')) die.tags.push('maximum')
    index++
  }
  // Drop dice that were appended but never rolled because the cap was hit.
  // Left in place they would show as 0 and could be kept by keep-lowest.
  dice.dice.splice(index)

  mods.forEach((mod) => mod.alterResult(dice))

  const expression = toString(amount, sides, modsText)
  const resultExpr: DicePart[] = dice.dice.map((die) => ({ type: 'die', die }))
  return {
    rolled: _.sum(dice.dice.filter((d) => !d.tags.includes('ignored')).map((d) => d.value)),
    average: moddedAvg,
    result: resultExpr,
    resultSimple: resultExpr,
    constant: false,
    expression,
    simplified: expression,
  }
}

/**
 * Formats a roll as `<amount>d<sides><mods>`.
 *
 * Plain advantage/disadvantage d20 rolls are abbreviated: a 1- or 2-die d20
 * roll with `h1` becomes `d20a`, and one with `l1` becomes `d20d`.
 *
 * @param amount   number of dice requested
 * @param sides    number of faces per die
 * @param modsText concatenated `toString()` output of the roll's modifiers
 */
function toString(amount: number, sides: number, modsText: string): string {
  const base = `${amount}d${sides}${modsText}`

  // The pattern is anchored to the start (the dice count) and must not be
  // followed by another digit. Otherwise `11d20h1` would become `1d20a` and
  // `2d20h10` would become `d20a0`.
  return base.replace(/^[12]d20h1(?!\d)/, 'd20a').replace(/^[12]d20l1(?!\d)/, 'd20d')
}
