import { IconProp } from '@fortawesome/fontawesome-svg-core'
import { faBrain } from '@fortawesome/pro-regular-svg-icons'
import {
  faHatWizard,
  faRabbitRunning,
  faScaleBalanced,
} from '@fortawesome/pro-solid-svg-icons'
import { t } from '@lingui/macro'
import { without } from 'lodash'

import { config } from 'config'
import { SavedMedia, SavedMediaContext } from 'modules/api'
import { FeatureFlags, featureFlags } from 'modules/featureFlags'
import { ProductKey } from 'modules/monetization'

import { ImageErrorType } from '../components/AIGeneratedImages/hooks/hooks'
import { StylePresetId } from '../components/AIGeneratedImages/types'
import FluxLogo from '../providers/icons/bfl.png'
import GeminiLogo from '../providers/icons/google-gemini.svg'
import IdeogramLogo from '../providers/icons/ideogram.svg'
import LeonardoLogo from '../providers/icons/leonardo.png'
import OpenAILogo from '../providers/icons/openai.svg'
import PlaygroundLogo from '../providers/icons/playground.svg'
import RecraftLogo from '../providers/icons/recraft.svg'
import StabilityLogo from '../providers/icons/stability.svg'

// Brand name shared by the Recraft model labels; it is interpolated into their `t` strings rather than typed inline
const RECRAFT_PROPER_NOUN = 'Recraft'
/** Width and height offered for one orientation (units presumably pixels — TODO confirm). */
export type AspectRatio = {
  key: AspectRatioKey
  // Display name as a function; the definitions in this file wrap a `t` string so it is resolved at call time
  name: () => string
  width: number
  height: number
}
/** Orientations an aspect ratio can be keyed by. */
export type AspectRatioKey = 'square' | 'landscape' | 'portrait'

/** Default aspect ratio: a 1024x1024 square. */
export const DEFAULT_ASPECT_RATIO: AspectRatio = {
  key: 'square',
  name: () => t`Square`,
  width: 1024,
  height: 1024,
}

/**
 * Identifiers of the image generation models known to this module.
 * IMAGE_GENERATE_MODELS is typed as a Record over this union, so every
 * identifier listed here needs an entry there.
 */
export type ImageGenerateModel =
  | 'stable-diffusion-xl-v1-0'
  | 'playground-2.5'
  | 'playground-3'
  | 'dall-e-3'
  | 'imagen-3-flash'
  | 'imagen-3-pro'
  | 'ideogram-v2-turbo'
  | 'ideogram-v2'
  | 'flux-1-schnell'
  | 'flux-1-pro'
  | 'flux-1-quick'
  | 'flux-1-ultra'
  | 'leonardo-phoenix'
  | 'recraft-v3'
  | 'recraft-v3-svg'
  | 'recraft-20b-icon'

// Fields shared by the two Ideogram entries of IMAGE_GENERATE_MODELS. Both entries
// spread this object first and then set their own label, description, flag,
// minProductTier and fallbackModels.
const IdeogramModel = {
  creatorLabel: () => 'Ideogram',
  icon: faBrain,
  image: IdeogramLogo,
  // Both entries that spread this object replace this value with 'ideogram'
  flag: 'ideogramTurbo',
  aspectRatios: {
    square: {
      key: 'square',
      name: () => t`Square`,
      width: 1536,
      height: 1536,
    },
    landscape: {
      key: 'landscape',
      name: () => t`Landscape`,
      width: 1792,
      height: 1344,
    },
    portrait: {
      key: 'portrait',
      name: () => t`Portrait`,
      width: 1344,
      height: 1792,
    },
  },
  provider: 'ideogram',
} as const

// If image generation on a model fails, fall back to a model in the same tier, prioritized by quality/cost.
// Each IMAGE_GENERATE_MODELS entry uses one of these lists, with the entry's own key removed
// via `without`, as its `fallbackModels`. Order matters: fetchGenerateImage takes the first
// available entry unless the default model is itself among the available fallbacks.
const FREE_FALLBACK_MODELS: ImageGenerateModel[] = [
  'flux-1-schnell',
  'flux-1-quick',
  'imagen-3-flash',
]

// Fallback candidates used by the entries whose fallbackModels are derived from the "plus" list
const PLUS_FALLBACK_MODELS: ImageGenerateModel[] = [
  'flux-1-pro',
  'imagen-3-pro',
  'ideogram-v2-turbo',
]

/** Static description of one image generation model (one value of IMAGE_GENERATE_MODELS). */
export type ImageModelInfo = {
  label: () => string
  description: () => string
  creatorLabel: () => string // Display name for the creator of this model
  icon: IconProp // If an image is provided it is used. The icon is only used when there is no image
  image?: StaticImageData
  minProductTier?: ProductKey // Minimum plan required to use this model (free if not specified)
  flag?: keyof FeatureFlags // Feature flag that enables the model; see isImageModelAvailable (the default model is exempt)
  freeFlag?: keyof FeatureFlags // When this feature flag is on, everyone gets the model: getRequiredPlanForImageModel returns 'free'
  disabledFlag?: keyof FeatureFlags // When this feature flag is on, the model is disabled in favor of another version
  aspectRatios: Partial<Record<AspectRatioKey, AspectRatio>>
  provider: ImageGenerateProvider
  // Candidates fetchGenerateImage filters by availability to choose the `fallbackModel` it sends
  fallbackModels?: ImageGenerateModel[]
  // NOTE(review): not read anywhere in this file — confirm where this is consumed
  needsPromptRewrite?: boolean
}

/**
 * Builds the square / landscape / portrait entries of a model from
 * `[width, height]` pairs, so each model only lists its dimensions.
 * A fresh object is returned on every call; nothing is shared between models.
 */
const aspectRatiosFrom = (
  square: readonly [number, number],
  landscape: readonly [number, number],
  portrait: readonly [number, number]
) => ({
  square: {
    key: 'square' as const,
    name: () => t`Square`,
    width: square[0],
    height: square[1],
  },
  landscape: {
    key: 'landscape' as const,
    name: () => t`Landscape`,
    width: landscape[0],
    height: landscape[1],
  },
  portrait: {
    key: 'portrait' as const,
    name: () => t`Portrait`,
    width: portrait[0],
    height: portrait[1],
  },
})

/**
 * Catalogue of every image generation model, keyed by model identifier.
 * Product names are intentionally not translated; descriptions are.
 */
export const IMAGE_GENERATE_MODELS: Record<ImageGenerateModel, ImageModelInfo> =
  {
    // --- Stability AI ---
    'stable-diffusion-xl-v1-0': {
      label: () => 'Stable Diffusion XL',
      description: () => t`Faster images in a variety of styles`,
      creatorLabel: () => 'Stability AI',
      icon: faRabbitRunning,
      image: StabilityLogo,
      flag: 'sdxlModel',
      fallbackModels: without(FREE_FALLBACK_MODELS, 'stable-diffusion-xl-v1-0'),
      minProductTier: 'free',
      aspectRatios: aspectRatiosFrom([1024, 1024], [1216, 832], [896, 1152]),
      provider: 'baseten',
    },

    // --- Playground ---
    'playground-2.5': {
      label: () => 'Playground 2.5',
      description: () =>
        t`Faster images with vivid colors, best for illustrations`,
      creatorLabel: () => 'Playground',
      icon: faRabbitRunning,
      image: PlaygroundLogo,
      flag: 'playgroundModel',
      fallbackModels: without(FREE_FALLBACK_MODELS, 'playground-2.5'),
      minProductTier: 'free',
      aspectRatios: aspectRatiosFrom([1024, 1024], [1216, 832], [896, 1152]),
      provider: 'playground',
    },
    'playground-3': {
      label: () => 'Playground 3',
      description: () =>
        t`Best for detailed prompts, capable of text and people`,
      creatorLabel: () => 'Playground',
      icon: faRabbitRunning,
      image: PlaygroundLogo,
      flag: 'playground3',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'playground-3'),
      minProductTier: 'plus',
      aspectRatios: aspectRatiosFrom([1024, 1024], [1216, 832], [896, 1152]),
      provider: 'playground',
    },

    // --- Black Forest Labs (Flux) ---
    'flux-1-schnell': {
      label: () => 'Flux Fast',
      description: () => t`Fastest model with good quality`,
      creatorLabel: () => 'Black Forest Labs',
      icon: faRabbitRunning,
      image: FluxLogo,
      flag: 'flux1Schnell',
      disabledFlag: 'flux1Quick',
      fallbackModels: without(FREE_FALLBACK_MODELS, 'flux-1-schnell'),
      aspectRatios: aspectRatiosFrom([1024, 1024], [1216, 832], [896, 1152]),
      provider: 'baseten',
    },
    'flux-1-quick': {
      label: () => 'Flux Fast 1.1',
      description: () => t`Fastest model with good quality`,
      creatorLabel: () => 'Black Forest Labs',
      icon: faRabbitRunning,
      image: FluxLogo,
      flag: 'flux1Quick',
      fallbackModels: without(FREE_FALLBACK_MODELS, 'flux-1-quick'),
      aspectRatios: aspectRatiosFrom([1024, 1024], [1216, 832], [896, 1152]),
      provider: 'flux',
    },
    'flux-1-pro': {
      label: () => 'Flux Pro',
      description: () => t`Professional quality people, faces, and text`,
      creatorLabel: () => 'Black Forest Labs',
      icon: faRabbitRunning,
      image: FluxLogo,
      flag: 'flux1Pro',
      minProductTier: 'plus',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'flux-1-pro'),
      aspectRatios: aspectRatiosFrom([1440, 1440], [1440, 960], [960, 1440]),
      provider: 'flux',
    },
    'flux-1-ultra': {
      label: () => 'Flux Ultra',
      description: () => t`Highest resolution images with fine details`,
      creatorLabel: () => 'Black Forest Labs',
      icon: faRabbitRunning,
      image: FluxLogo,
      flag: 'flux1Ultra',
      minProductTier: 'pro',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'flux-1-ultra'),
      aspectRatios: aspectRatiosFrom([2048, 2048], [2752, 1536], [1536, 2752]),
      provider: 'flux',
    },

    // --- Google (Imagen) ---
    'imagen-3-flash': {
      label: () => 'Imagen 3 Fast',
      description: () =>
        t`Google's faster model, good for detailed instructions`,
      creatorLabel: () => 'Google',
      icon: faScaleBalanced,
      image: GeminiLogo,
      minProductTier: 'free',
      flag: 'imagenFlash',
      fallbackModels: without(FREE_FALLBACK_MODELS, 'imagen-3-flash'),
      aspectRatios: aspectRatiosFrom([1536, 1536], [1792, 1344], [1344, 1792]),
      provider: 'google',
    },
    'imagen-3-pro': {
      label: () => 'Imagen 3',
      description: () =>
        t`Google's most advanced model, good for text and people`,
      creatorLabel: () => 'Google',
      icon: faScaleBalanced,
      image: GeminiLogo,
      minProductTier: 'plus',
      flag: 'imagen3',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'imagen-3-pro'),
      aspectRatios: aspectRatiosFrom([1536, 1536], [1792, 1344], [1344, 1792]),
      provider: 'google',
    },

    // --- Ideogram (shared fields come from IdeogramModel; later keys override them) ---
    'ideogram-v2-turbo': {
      ...IdeogramModel,
      label: () => 'Ideogram 2.0 Turbo',
      description: () => t`Fast and good for text`,
      flag: 'ideogram',
      minProductTier: 'plus',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'ideogram-v2-turbo'),
    },
    'ideogram-v2': {
      ...IdeogramModel,
      label: () => 'Ideogram 2.0',
      description: () => t`Best for text, high quality overall`,
      flag: 'ideogram',
      minProductTier: 'pro',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'ideogram-v2'),
    },

    // --- OpenAI ---
    'dall-e-3': {
      label: () => 'DALL·E 3',
      description: () =>
        t`OpenAI's most advanced model, high quality but slower`,
      creatorLabel: () => 'OpenAI',
      icon: faHatWizard,
      image: OpenAILogo,
      minProductTier: 'pro',
      flag: 'dalle3',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'dall-e-3'),
      aspectRatios: aspectRatiosFrom([1024, 1024], [1792, 1024], [1024, 1792]),
      provider: 'azure',
    },

    // --- Leonardo ---
    'leonardo-phoenix': {
      label: () => 'Leonardo Phoenix',
      description: () => t`Good for creative styles and text`,
      creatorLabel: () => 'Canva',
      icon: faHatWizard,
      image: LeonardoLogo,
      flag: 'leonardoPhoenix',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'leonardo-phoenix'),
      minProductTier: 'plus',
      aspectRatios: aspectRatiosFrom([1024, 1024], [1376, 768], [832, 1248]),
      provider: 'leonardo',
    },

    // --- Recraft ---
    'recraft-v3': {
      label: () => RECRAFT_PROPER_NOUN,
      description: () => t`Good for artistic and creative styles`,
      creatorLabel: () => 'Recraft',
      icon: faRabbitRunning,
      image: RecraftLogo,
      flag: 'recraftModel',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'recraft-v3'),
      minProductTier: 'pro',
      aspectRatios: aspectRatiosFrom([1024, 1024], [1707, 1024], [1024, 1707]),
      provider: 'recraft',
    },
    'recraft-v3-svg': {
      label: () => t`${RECRAFT_PROPER_NOUN} Vector Illustration`,
      description: () => t`For line art and engravings`,
      creatorLabel: () => 'Recraft',
      icon: faRabbitRunning,
      image: RecraftLogo,
      flag: 'recraftSvgModel',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'recraft-v3-svg'),
      minProductTier: 'pro',
      aspectRatios: aspectRatiosFrom([1024, 1024], [1707, 1024], [1024, 1707]),
      provider: 'recraft',
    },
    'recraft-20b-icon': {
      label: () => t`${RECRAFT_PROPER_NOUN} Icon`,
      description: () => t`For icons or logos`,
      creatorLabel: () => 'Recraft',
      icon: faRabbitRunning,
      image: RecraftLogo,
      flag: 'recraftSvgModel',
      fallbackModels: without(PLUS_FALLBACK_MODELS, 'recraft-20b-icon'),
      minProductTier: 'pro',
      aspectRatios: aspectRatiosFrom([1024, 1024], [1707, 1024], [1024, 1707]),
      provider: 'recraft',
    },
  } as const

/** An ImageModelInfo together with the key it is stored under (the element type returned by orderedImageModelInfos). */
export type ImageModelInfoWithKey = ImageModelInfo & { model: string }

/**
 * Lists every entry of IMAGE_GENERATE_MODELS, each tagged with its key as
 * `model`, ordered alphabetically by label.
 *
 * @returns A new array on every call.
 */
export const orderedImageModelInfos = (): ImageModelInfoWithKey[] => {
  const modelKeys = Object.keys(IMAGE_GENERATE_MODELS) as ImageGenerateModel[]
  const infos = modelKeys.map((modelKey) => ({
    model: modelKey,
    ...IMAGE_GENERATE_MODELS[modelKey],
  }))

  // Case-insensitive comparison in English: model names are proper nouns and are not localized
  const compareLabels = (
    left: { label: () => string },
    right: { label: () => string }
  ) => left.label().localeCompare(right.label(), 'en-US', { sensitivity: 'base' })

  return infos.sort(compareLabels)
}

/**
 * Looks up the info for `model`; when there is no entry for it, returns the
 * info of the model named by the `aiGeneratedImagesDefaultModel` feature flag.
 */
export const getImageModelInfo = (model: ImageGenerateModel) => {
  const requestedInfo = IMAGE_GENERATE_MODELS[model]
  if (requestedInfo) {
    return requestedInfo
  }
  // The feature flag is only read when the requested model has no entry
  const defaultModel = featureFlags.get('aiGeneratedImagesDefaultModel')
  return IMAGE_GENERATE_MODELS[defaultModel]
}

/** Allowed values of `ImageModelInfo.provider`. */
export type ImageGenerateProvider =
  | 'baseten'
  | 'openai'
  | 'azure'
  | 'google'
  | 'playground'
  | 'ideogram'
  | 'flux'
  | 'leonardo'
  | 'recraft'

/** Options for an image generation request; `fetchGenerateImage` accepts them as a `Partial`. */
export type GenerateImageOptions = {
  interactionId: string
  prompt: string
  workspaceId: string
  // When set, fetchGenerateImage sends context = SavedMediaContext.Theme (takes precedence over docId)
  themeId?: string
  upscaleFactor?: number
  // When set and themeId is not, fetchGenerateImage sends context = SavedMediaContext.Doc
  docId?: string
  model: ImageGenerateModel
  // fetchGenerateImage replaces this with the fallback it picks from the model's fallbackModels
  fallbackModel?: ImageGenerateModel
  // Only passed through by fetchGenerateImage when neither themeId nor docId is given
  context?: SavedMediaContext
  // dont export this year
  // NOTE(review): the intent of the comment above is unclear — confirm or remove
  count: number
  width?: number
  height?: number
  // AI Image Generation
  stylePreset?: StylePresetId
  stylePrompt?: string
  // For theme generation use
  negative_prompt?: string
  steps?: number
  cfg_scale?: number
  rewrite?: boolean
}

/**
 * Requests image generation from the `/media/images/generate` endpoint.
 *
 * The saved-media context is derived from the ids supplied: `themeId` wins
 * over `docId`, and when neither is present the caller's `context` is passed
 * through. A missing `model` falls back to the `aiGeneratedImagesDefaultModel`
 * feature flag. The `fallbackModel` sent is chosen from the model's available
 * `fallbackModels`, preferring the default model when it is among them; it
 * replaces any `fallbackModel` supplied in `options`.
 *
 * @param options Generation options; `themeId` / `docId` decide the context.
 * @returns The parsed JSON body of a successful response.
 * @throws ProhibitedInputError when the API answers with code
 *   `prohibited_input` and a `categories` object.
 * @throws GenerateImageError for every other non-OK response.
 */
export const fetchGenerateImage = async (
  options: Partial<GenerateImageOptions>
): Promise<SavedMedia[]> => {
  const { themeId, docId, ...rest } = options
  // set the context based on themeId or docId
  const contextObj: Pick<
    GenerateImageOptions,
    'docId' | 'themeId' | 'context'
  > = themeId
    ? {
        context: SavedMediaContext.Theme,
        themeId,
      }
    : docId
    ? {
        context: SavedMediaContext.Doc,
        docId,
      }
    : {
        // we dont have either docId or themeId, respect the context
        // passed in
        context: options.context,
      }

  const defaultModel = featureFlags.get('aiGeneratedImagesDefaultModel')
  const model = options.model ?? defaultModel
  const url = `${
    config.API_HOST || 'https://api.gamma.app'
  }/media/images/generate`

  const fallbackModels =
    getImageModelInfo(model).fallbackModels?.filter(isImageModelAvailable) || []
  const fallbackModel: ImageGenerateModel | undefined = fallbackModels.includes(
    defaultModel
  )
    ? defaultModel
    : fallbackModels[0]

  const response = await fetch(url, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({
      ...rest,
      // Placed after `rest`: an explicit `model: undefined` in options would
      // otherwise overwrite the resolved default and be dropped by JSON.stringify
      model,
      fallbackModel,
      ...contextObj,
    }),
    credentials: 'include',
  })

  if (!response.ok) {
    // Error bodies are expected to be JSON, but a non-JSON body (e.g. an HTML
    // error page) must still surface as a GenerateImageError, not a SyntaxError
    const json = await response.json().catch(() => ({
      error: response.status,
      message: response.statusText,
    }))
    if (json.code === 'prohibited_input' && json.categories) {
      throw new ProhibitedInputError({ categories: json.categories })
    }
    const message = `${json.error}: ${json.message}`
    throw new GenerateImageError(message, json)
  }
  return response.json()
}

/**
 * Error thrown by fetchGenerateImage for a non-OK response that is not a
 * prohibited-input rejection. The parsed response body is kept on `response`.
 */
class GenerateImageError extends Error {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- untyped API payload read by callers
  response: any

  /**
   * @param message Human-readable summary of the failure.
   * @param response Parsed body of the failed response.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  constructor(message: string, response: any) {
    // Only the message is forwarded: Error's second parameter is an options
    // bag ({ cause }), so passing the payload there would copy a `cause`
    // field from the response onto the error
    super(message)
    this.response = response
  }
}
GenerateImageError.prototype.name = 'GenerateImageError'

/**
 * Translated message for each ImageErrorType. Values are functions so the
 * `t` string is evaluated when the message is requested; ProhibitedInputError
 * looks its `messageTranslated` up here.
 */
export const ImageErrorMessages: Record<ImageErrorType, () => string> = {
  sexual: () =>
    t`This prompt was blocked because it could generate sexual imagery.`,
  violence: () =>
    t`This prompt was blocked because it could generate violent imagery.`,
  prohibited: () =>
    t`This prompt was blocked because it could generate inappropriate content.`,
}

/**
 * Error thrown when the API rejects a prompt as prohibited input.
 *
 * `category` is the coarse reason derived from the reported categories and
 * `messageTranslated` is the matching entry of ImageErrorMessages, with a
 * generic translated message as the fallback.
 */
export class ProhibitedInputError extends Error {
  messageTranslated: string
  code = 'prohibited_input'
  category: string

  /**
   * @param categories Categories reported by the API; their keys are listed
   *   in the (untranslated) error message.
   */
  constructor({ categories }: { categories: Record<ImageErrorType, any> }) {
    const message = `Cannot generate image, prohibited (reasons=${JSON.stringify(
      Object.keys(categories)
    )})`
    super(message)
    this.category = this.parseCategories(categories)
    // Guarded call: an unknown category falls back to the generic message instead of throwing
    this.messageTranslated =
      ImageErrorMessages[this.category]?.() || t`This prompt was blocked`
  }

  /**
   * Collapses the reported categories into a key of ImageErrorMessages.
   *
   * @returns 'sexual', 'violence', or 'prohibited' for anything else.
   */
  parseCategories(categories: Record<ImageErrorType, any>): string {
    if (categories.sexual || categories['sexual/minors']) {
      return 'sexual'
    } else if (categories.violence || categories['violence/graphic']) {
      return 'violence'
    } else {
      // Must match the `prohibited` key of ImageErrorMessages exactly
      return 'prohibited'
    }
  }
}

/**
 * Tells whether `model` is a known image model that is currently usable.
 *
 * A model is unavailable when it has no entry in IMAGE_GENERATE_MODELS, when
 * its `flag` is off (unless it is the `aiGeneratedImagesDefaultModel`), or
 * when its `disabledFlag` is on.
 *
 * @param model Model identifier; may be an arbitrary string.
 * @returns `true` (narrowing `model` to ImageGenerateModel) when usable.
 */
export const isImageModelAvailable = (
  model: ImageGenerateModel | string
): model is ImageGenerateModel => {
  // Own-key check: a plain lookup would also match keys inherited from
  // Object.prototype ('constructor', 'toString', ...) and report them as models
  if (!Object.prototype.hasOwnProperty.call(IMAGE_GENERATE_MODELS, model)) {
    return false
  }
  const modelInfo = IMAGE_GENERATE_MODELS[model]
  if (!modelInfo) return false
  if (
    modelInfo.flag &&
    !featureFlags.get(modelInfo.flag) &&
    // The default model stays available even when its own flag is off
    model !== featureFlags.get('aiGeneratedImagesDefaultModel')
  ) {
    return false
  }
  if (modelInfo.disabledFlag && featureFlags.get(modelInfo.disabledFlag)) {
    return false
  }
  return true
}

/**
 * Returns the minimum plan needed to use `model`.
 *
 * Unknown models, and models whose `freeFlag` feature flag is on, need only
 * the free plan; otherwise the model's `minProductTier` applies, defaulting
 * to 'free'.
 */
export const getRequiredPlanForImageModel = (
  model: ImageGenerateModel | string
): ProductKey => {
  const info = IMAGE_GENERATE_MODELS[model]
  // The feature flag is only consulted for known models that declare a freeFlag
  const freeForEveryone = Boolean(
    info?.freeFlag && featureFlags.get(info.freeFlag)
  )
  if (!info || freeForEveryone) {
    return 'free'
  }
  return info.minProductTier || 'free'
}
