import { removeExtraSpaces } from './removeExtraSpaces';
import { generateModelDescription } from './trigger/generateModelDescription';
import { getGenderAssociation } from './getGenderAssociation';
import { Gender_Association, tag_category_type, Modifier, Prisma } from '@acme/db';
import { prisma } from '@acme/db/client'
import { logger } from './logger';
import { TrainingMeta } from './types/trainingMeta';


/**
 * Expands a prompt template for a trained model and gathers the modifiers its wildcards need.
 *
 * 1. Builds a model description from `trainingMeta` via `generateModelDescription`
 *    (forwarding `enabledAge` and `modifiers`) and passes it through `removeExtraSpaces`.
 * 2. Substitutes that description for the `{trigger}` / `{model}` placeholders in both
 *    the positive and the negative prompt.
 * 3. Treats every remaining `{...}` token in the positive prompt (not the negative one)
 *    as a wildcard and fetches its modifiers for the model's gender association.
 *
 * @param prompt - Positive prompt template.
 * @param negativePrompt - Negative prompt template; placeholders are substituted, but
 *   wildcards are not collected from it.
 * @param trainingMeta - Training metadata used for the description and the gender lookup.
 * @param modifiers - Optional key/value modifiers forwarded to `generateModelDescription`.
 * @param enabledAge - Forwarded to `generateModelDescription`.
 * @returns The substituted `prompt` and `negativePrompt`, the wildcard lookup result as
 *   `categories`, and the model description as `modelDescriptor`.
 */
export const processPrompt = async (prompt: string, negativePrompt: string, trainingMeta: TrainingMeta, modifiers?: { [key: string]: string } | undefined, enabledAge?: boolean) => {
  const modelDescriptor = removeExtraSpaces(
    generateModelDescription(trainingMeta, { enabledAge, modifiers }),
  );

  const expandedPrompt = replaceModelPlaceholder(prompt, modelDescriptor);
  const expandedNegativePrompt = replaceModelPlaceholder(negativePrompt, modelDescriptor);

  // Any `{...}` token still present after placeholder substitution is a wildcard;
  // strip the surrounding braces to get its name.
  const wildcardTokens = expandedPrompt.match(/{.*?}/g) ?? [];
  const wildcardNames = wildcardTokens.map((token) => token.slice(1, -1));

  // NOTE(review): the non-null assertion assumes getGenderAssociation always resolves
  // trainingMeta.gender — confirm, otherwise fetchWildcards receives undefined.
  const genderAssociation = getGenderAssociation(trainingMeta.gender)!;
  const categories = await fetchWildcards(wildcardNames, genderAssociation);

  return {
    prompt: expandedPrompt,
    negativePrompt: expandedNegativePrompt,
    categories,
    modelDescriptor,
  };
}


/**
 * Loads the modifiers that back a prompt's wildcards.
 *
 * Each entry of `tagNamesOrCategories` is either a `tag_category_type` value (a whole
 * category) or a plain tag name:
 * - Categories are resolved with a raw SQL query that keeps at most
 *   `MAX_MODIFIERS_PER_TAG` randomly ordered rows per (category, tag_name) pair.
 * - Tag names are resolved with `prisma.modifier.findMany`, with no per-tag limit.
 *
 * Both lookups only return modifiers whose gender association is the one derived from
 * `gender`, or gender-neutral. A lookup with nothing to look up is skipped, and the two
 * lookups run concurrently because neither depends on the other.
 *
 * @param tagNamesOrCategories - Wildcard names taken from a prompt (categories and/or tag names).
 * @param gender - Value passed to `getGenderAssociation`; a falsy result restricts both
 *   lookups to gender-neutral modifiers.
 * @returns `categoryModifiers` from the category lookup and `tagModifiers` from the tag lookup.
 */
export const fetchWildcards = async (tagNamesOrCategories: string[], gender: string) => {
  logger.debug('Fetching modifiers', { tagNamesOrCategories, gender });

  const genderAssociation = getGenderAssociation(gender);

  // Split the input into category wildcards and plain tag names in one pass, checking
  // membership against a Set built once rather than calling Object.values(...) per item.
  const categoryValues = new Set<string>(Object.values(tag_category_type));
  const validCategories: tag_category_type[] = [];
  const nonCategoryTags: string[] = [];
  for (const item of tagNamesOrCategories) {
    if (categoryValues.has(item)) {
      validCategories.push(item as tag_category_type);
    } else {
      nonCategoryTags.push(item);
    }
  }
  logger.debug('Valid categories', { validCategories });
  logger.debug('Non-category tags', { nonCategoryTags });

  const MAX_MODIFIERS_PER_TAG = 40;

  // The raw query compares `gender_association::text`, hence the literal 'gender-neutral'
  // (presumably the database label behind Gender_Association.gender_neutral — verify against
  // the Prisma schema). getGenderAssociation is applied again to the already-normalized value;
  // this assumes the helper is idempotent — TODO confirm.
  const genderAssociations = genderAssociation
    ? [getGenderAssociation(genderAssociation), 'gender-neutral']
    : ['gender-neutral'];

  // Category lookup. ROW_NUMBER() produces a Postgres bigint column `rn`, which Prisma's raw
  // query mapping returns as a JS BigInt. `SELECT *` would leak it into rows typed as Modifier
  // (and BigInt breaks JSON.stringify), so it is stripped before returning.
  const fetchCategoryModifiers = async (): Promise<Modifier[]> => {
    if (validCategories.length === 0) return [];
    const rows = await prisma.$queryRaw<(Modifier & { rn: unknown })[]>`
      WITH ranked_modifiers AS (
        SELECT *,
               ROW_NUMBER() OVER (PARTITION BY category, tag_name ORDER BY RANDOM()) as rn
        FROM "prompttemplates"."modifiers"
        WHERE category::text IN (${Prisma.join(validCategories)})
          AND gender_association::text IN (${Prisma.join(genderAssociations)})
      )
      SELECT * FROM ranked_modifiers
      WHERE rn <= ${MAX_MODIFIERS_PER_TAG}
    `;
    return rows.map(({ rn: _rn, ...modifier }) => modifier);
  };

  // Tag-name lookup; skipped (no DB round trip) when every wildcard was a category.
  const fetchTagModifiers = async (): Promise<Modifier[]> => {
    if (nonCategoryTags.length === 0) return [];
    return prisma.modifier.findMany({
      where: {
        tag_name: { in: nonCategoryTags },
        gender_association: {
          in: genderAssociation
            ? [genderAssociation, Gender_Association.gender_neutral]
            : [Gender_Association.gender_neutral],
        },
      },
    });
  };

  const [categoryModifiers, tagModifiers] = await Promise.all([
    fetchCategoryModifiers(),
    fetchTagModifiers(),
  ]);

  logger.debug('Fetched modifiers', { categoryModifiersCount: categoryModifiers.length, tagModifiersCount: tagModifiers.length });
  return { categoryModifiers, tagModifiers };
};

/**
 * Substitutes the model description for the `{trigger}` and `{model}` placeholders.
 *
 * Ex: a photo of a {trigger} => a photo of a man TOK
 *
 * Both placeholders are replaced in a single pass using a replacer function, so the
 * inserted text is taken verbatim. `$`-patterns in `trigger` (e.g. `$&`, `$'`) are not
 * expanded, and a `{model}` occurring inside `trigger` is not substituted a second time.
 *
 * @param text - Prompt text that may contain `{trigger}` / `{model}` placeholders.
 * @param trigger - Text to insert for every placeholder.
 * @returns `text` with every placeholder replaced by `trigger`.
 */
const replaceModelPlaceholder = (text: string, trigger: string): string =>
  text.replace(/\{(?:trigger|model)\}/g, () => trigger);