import os
import copy
import json
import sqlite3
import time
import re

import xml.etree.ElementTree as ET

import KCADBProcessorUtilities as KCAProc
import KCALLMTrace as KCATrace
import KCALLMUtilities as KCAUtil
import KCALLMMainGenEmbedding
import KCALLMTrace


def extractFirstFiguresFromString(txt):
    """Return the integer part of the first number found in *txt*.

    A number is a run of digits with an optional decimal part ("12", "2.5",
    "3.") or a leading-dot decimal (".5"). Lone dots and dotted runs such as
    "1.2.3" are not treated as numbers; the latter yields 1.

    Returns:
        int: the truncated first number, or 0 when *txt* contains none.
    """
    # Only the first match matters; a stricter pattern avoids feeding "." or
    # "1.2.3" to float(), which would raise ValueError
    match = re.search(r'\d+(?:\.\d*)?|\.\d+', txt)
    if match is None:
        return 0
    return int(float(match.group()))

def adjustQuantity(q):
    """Return *q* with every double quote replaced by a single quote.

    Works around LLM output that puts double quotes inside quantities.
    """
    return q.replace('"', "'")


def sortFood(f):
    """Return the sort key used to rank candidate products (lower is better).

    Candidates are ordered by, in priority:
      1. position of the searched normalized name in the product's normalized
         name ("posiNormName"); -1 (name not found) ranks after every match;
      2. shortest "name";
      3. shortest "comment";
      4. highest "rank";
      5. products whose "gtin" equals their "gtinRef" first.

    None values for "name", "comment" (treated as empty) and "rank"
    (treated as 0) are accepted, as they come from nullable DB columns.
    """
    # Position should be as low as possible, but "not found" must not win
    position = f["posiNormName"]
    if position < 0:
        position = float("inf")

    # Name and comment should be short
    nameLen = len(f["name"] or "")
    commentLen = len(f["comment"] or "")

    # Rank is key: a higher rank gives a lower key
    rankKey = 1000000 - (f["rank"] or 0)

    # Priority to focused GTIN
    gtinKey = 0 if f["gtin"] == f["gtinRef"] else 1

    return (position, nameLen, commentLen, rankKey, gtinKey)


def extractPortraitRobots(sols, dbPath, pictosPath):
    """Concatenate the nutrition "portrait robots" of several solutions.

    Args:
        sols: iterable of solution dicts, each passed to getPortraitRobot().
        dbPath: SQLite database path forwarded to getPortraitRobot().
        pictosPath: picto configuration root forwarded to getPortraitRobot().

    Returns:
        A single string with, per solution, a header naming the food, its
        portrait text and a trailing newline ("" when *sols* is empty).
    """
    sections = []
    for sol in sols:
        foodName, portrait = getPortraitRobot(sol, dbPath, pictosPath)
        sections.append('For "' + foodName + '", here are the nutrition values:\n' + portrait + "\n")
    return "".join(sections)


def getPortraitRobot(sol, dbPath, pictosPath):
    """Build the nutrition summary ("portrait robot") of one product.

    Args:
        sol: solution dict; only its "id" entry (a V_ID of KCALME_TABLE) is read.
        dbPath: path of the SQLite database holding KCALME_TABLE.
        pictosPath: root folder of the picto configurations, forwarded to
            getUnitBase() to find the reference unit of the product.

    Returns:
        (foodName, portraitRobot): the product name and a multi-line
        "key: value" description. Both are "" when the solution has no id or
        when the id does not match exactly one row, so callers can always
        unpack two values.
    """
    fid = sol["id"]

    # No product to describe: keep the two-value return shape
    if not fid:
        return "", ""

    table = "KCALME_TABLE"
    param = ["V_Name", "V_Comment", "V_PackType", "V_Source", "V_GTIN", "V_Trademark", "V_ID", "V_CaloriePerUnit", "V_SaltPerUnit", "V_WaterRatio", "V_SugarPerUnit", "V_NutriScore", "V_Ecoscore", "V_Allergens", "V_AllergenTraces"]

    # The id is bound as a parameter instead of being spliced into the SQL
    q = "SELECT " + ",".join(param) + " FROM " + table + " WHERE V_ID = ?"

    rows = []
    conn = sqlite3.connect(dbPath)
    try:
        dbCursor = conn.cursor()
        try:
            dbCursor.execute(q, (fid,))
            rows = dbCursor.fetchall()
        except sqlite3.Error as e:
            print(f"An error occurred: {e}")
        finally:
            dbCursor.close()
    finally:
        # Release the connection even when the query fails
        conn.close()

    # The id is expected to be unique: anything else is a lookup failure
    if len(rows) != 1:

        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("ERROR: UNKNOWN PRODUCT: ID=" + str(fid))
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

        return "", ""

    row = dict(zip(param, rows[0]))

    # A NULL pack type is handled as "no pack information"
    unitBase, volume, weight = getUnitBase(row["V_PackType"] or "", pictosPath)

    slots = [
        formatSlot('name', row["V_Name"], ""),
        formatSlot('GTIN', row["V_GTIN"], ""),
        formatSlot('brand', row["V_Trademark"], ""),
        formatSlot('calorie', row["V_CaloriePerUnit"], "Kcal per " + unitBase),
    ]
    if weight != 0:
        slots.append(formatSlot('reference weight for a unity', weight, "g"))
    if volume != 0:
        slots.append(formatSlot('reference volume for a unity', volume, "ml"))
    slots += [
        formatSlot('salt', row["V_SaltPerUnit"], "g per " + unitBase),
        formatSlot('sugar', row["V_SugarPerUnit"], "g per " + unitBase),
        formatSlot('NutriScore', row["V_NutriScore"], ""),
        formatSlot('EcoScore', row["V_Ecoscore"], ""),
        formatSlot('allergens', row["V_Allergens"], ""),
        formatSlot('allergen traces', row["V_AllergenTraces"], ""),
        formatSlot('data source', getDataSource(row["V_Source"]), ""),
    ]

    return row["V_Name"], "".join(slots)


def getUnitBase(pack, pictosPath):
    """Derive the nutrition reference unit of a product from its pack codes.

    *pack* is a comma-separated list of pack codes. The first three
    characters of each code (its trigram) select the picto configuration
    ``<pictosPath>/<trigram>/config_fr.xml``. The ``dimensions`` child of
    that file gives a quantity and a unit:

    * unit "ml" (any case): values are per 100ml and *volume* is the quantity;
    * unit "PO": values are per 100g and *weight* is read from the code
      itself, as the text following ``<trigram>.w`` (e.g. "POR.w150" -> 150).

    Later codes override the values set by earlier ones.

    Returns:
        (unitPack, volume, weight): "100g" or "100ml", plus the reference
        volume (ml) and weight (g), 0 when unknown.

    Raises:
        OSError: if a configuration file cannot be opened.
        xml.etree.ElementTree.ParseError: if a configuration is not valid XML.
        ValueError: if a ``quantity`` attribute is not an integer.
    """
    unitPack = '100g'
    volume = 0
    weight = 0
    print(pack)
    for p in pack.split(','):

        # Codes shorter than a trigram carry no usable information
        if len(p) < 3:
            continue
        trig = p[0:3]
        configPath = os.path.join(pictosPath, trig, "config_fr.xml")

        # ET.parse reads raw bytes and honours the XML encoding declaration,
        # instead of decoding with the platform default encoding
        root = ET.parse(configPath).getroot()

        # Quantity (missing attributes mean "no information")
        quantity = 0
        unit = ''
        dimensions = root.find('dimensions')
        if dimensions is not None:
            quantity = int(dimensions.get('quantity', 0))
            unit = dimensions.get('unit', '')

        # Settings
        if unit.lower() == 'ml':
            unitPack = '100ml'
            volume = quantity
        elif unit == 'PO':
            try:
                weight = int(p.replace(trig + '.w', ''))
            except ValueError:
                print("ERROR with converion of " + p)
            unitPack = '100g'

    return unitPack, volume, weight

def getDataSource(code):
    """Translate a data-source code into a readable label.

    Known codes are 'CIQ' (Ciqual), 'KCA' (KcalMe) and 'OFF'
    (Open Food Facts). Any other value is returned unchanged.
    """
    knownSources = (
        ('CIQ', 'Ciqual'),
        ('KCA', 'KcalMe'),
        ('OFF', 'Open Food Facts'),
    )
    for known, label in knownSources:
        if code == known:
            return label
    return code


def formatSlot(name, value, unit):
    """Format one "name: value<unit>" line of a portrait robot.

    Missing values (None, e.g. a NULL DB column, or anything whose string
    form is empty) are written as "name: none", without the unit.

    Returns:
        The formatted line, terminated by a newline.
    """
    # None must not be rendered as the text "None" followed by the unit
    sval = '' if value is None else str(value)
    if sval == '':
        return name + ": none\n"
    return name + ": " + sval + unit + "\n"


def _executeSelect(dbCursor, label, q, args):
    """Print a parameterized SELECT under *label*, run it and return its rows.

    SQL errors are printed and reported as an empty result, so the caller can
    fall through to its next lookup strategy instead of crashing.
    """
    print()
    print(label)
    print(q)
    print(args)
    print()
    try:
        dbCursor.execute(q, args)
        return dbCursor.fetchall()
    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
        return []


def _findCandidateRows(dbCursor, selectHead, name, brand, company,
                       normName, normBrand, normCompany, dbEmbeddingPath):
    """Run the cascade of product lookups for one food.

    1. normalized name + normalized trademark (or explicitly no trademark);
    2. if nothing was found: exact V_Name of the closest embedding word or,
       when the embedding search returns nothing, the normalized name in both
       V_NormName and V_NormAggr;
    3. if still nothing and both brand and company are known: name, brand
       and company all present in V_NormAggr.

    User-derived values are bound as parameters, so quotes in names
    (e.g. "jus d'orange") can neither break nor alter the SQL.

    Returns:
        (rows, q): the rows of the last query run and that query's text.
    """
    likeName = "%" + normName + "%"

    # First try on trademark
    qBrand = normBrand.strip()
    q = selectHead + " WHERE V_NormName LIKE ?"
    args = [likeName]
    if qBrand != '':
        q += " AND V_NormTrademark LIKE ?"
        args.append("%" + qBrand + "%")
    else:
        q += " AND (V_NormTrademark = '' OR V_NormTrademark IS NULL)"
    rows = _executeSelect(dbCursor, "First try:", q, args)
    if rows:
        return rows, q

    # Second try: closest name (embedding), else aggregation
    resTab = KCALLMMainGenEmbedding.searchBestSolutions(name, dbEmbeddingPath, 5)
    if len(resTab) != 0:
        embWord = resTab[0]["word"]
        print("Found embedding word: " + embWord)
        q = selectHead + " WHERE V_Name = ?"
        args = [embWord]
    else:
        q = selectHead + " WHERE V_NormName LIKE ? AND V_NormAggr LIKE ?"
        args = [likeName, likeName]
    rows = _executeSelect(dbCursor, "Second try:", q, args)

    # Third try (desperate)
    if len(rows) == 0 and company != "" and brand != "":
        q = (selectHead + " WHERE V_NormAggr LIKE ? AND V_NormAggr LIKE ?"
             " AND V_NormAggr LIKE ?")
        args = [likeName, "%" + normBrand + "%", "%" + normCompany + "%"]
        rows = _executeSelect(dbCursor, "Third try:", q, args)

    return rows, q


def _selectBestSolution(sorted_sols, quantityLem, jpic):
    """Pick the final product among ranked candidates.

    Only candidates sharing the normalized name of the best-ranked one are
    considered. The first of them for which getBestPicto() yields a serving
    is returned with its "serving" filled in. Otherwise the best-ranked
    candidate is returned unchanged.
    """
    best = sorted_sols[0]
    refName = best["normName"]
    for cand in sorted_sols:
        # Necessary condition: same normalized name as the best candidate
        if cand["normName"] != refName:
            break
        serving = getBestPicto(quantityLem, cand["pack"].copy(), jpic)
        if serving != "":
            cand["serving"] = serving
            return cand
        KCATrace.TRACE("ERROR: no solution for picto in the first solution")
    return best


def getBestSolutions(iSols, dbPath, dbEmbeddingPath, jpic):
    """Match LLM-extracted foods against the product database.

    Args:
        iSols: one food dict or a list of them. Each is passed through
            KCAUtil.ignorePrefix("food:", ...) (presumably stripping a "food:"
            key prefix). "name" is required; "quantity", "brand", "company",
            "type", "time"/"timeOfTheDay" and "event" are optional.
        dbPath: SQLite product database (table KCALME_TABLE).
        dbEmbeddingPath: embedding database used as a fallback name search.
        jpic: picto configuration forwarded to getBestPicto().

    Returns:
        A list holding, for each food that matched at least one row, the
        selected product dict. Foods without any match are traced and skipped.

    Raises:
        SystemExit: with status 1 when one of the database paths is missing.
    """
    fsols = []

    # Check: a missing database is fatal and reported as a failure
    if not os.path.exists(dbPath):
        KCALLMTrace.TRACE_ERROR("The DB path " + dbPath + " does not exist")
        raise SystemExit(1)
    if not os.path.exists(dbEmbeddingPath):
        KCALLMTrace.TRACE_ERROR("The embedding DB path " + dbEmbeddingPath + " does not exist (update it!)")
        raise SystemExit(1)

    # Nothing to match (empty dict / list, or None)
    if not iSols:
        return fsols
    foodList = iSols if isinstance(iSols, list) else [iSols]

    table = "KCALME_TABLE"
    param = ["V_Name", "V_Comment", "V_NormName", "V_NormComment", "V_PackType", "V_GTIN", "V_GTINRef", "V_ID", "V_GlobalCount", "V_NormTrademark", "V_Trademark", "V_NormAggr"]
    selectHead = "SELECT " + ",".join(param) + " FROM " + table

    conn = sqlite3.connect(dbPath)
    dbCursor = conn.cursor()
    try:
        for rawFood in foodList:

            # Reformat the result
            f = KCAUtil.ignorePrefix("food:", rawFood)

            print()
            print("----------- result to be analyzed -----------")
            print(f)

            # Semantic
            name = f["name"]

            # Quantity
            rawQuantity = f.get("quantity")
            quantityLLM = adjustQuantity("" if rawQuantity is None else str(rawQuantity))
            quantityLem = KCAProc.normalizeQuantityProposition(quantityLLM)

            # Optional attributes; a null brand/company means "unknown"
            brand = f.get("brand") or ""
            company = f.get("company") or ""
            typef = f.get("type", "")
            timeM = f["time"] if "time" in f else f.get("timeOfTheDay", "")
            eventtime = f.get("event", "unknown")

            # Remove brand from name, unless the name is the brand itself
            if name is not None and brand != "" and brand != name and brand in name:
                name = name.replace(brand, "")

            # Normalize
            normName = KCAProc.normalizeName(name)
            normBrand = KCAProc.normalizeBrand(brand)
            normCompany = KCAProc.normalizeName(company)

            rows, q = _findCandidateRows(dbCursor, selectHead, name, brand, company,
                                         normName, normBrand, normCompany, dbEmbeddingPath)

            # Parse (NULL text columns are tolerated where they are processed)
            sols = []
            for row in rows:
                rec = dict(zip(param, row))
                sols.append({
                    "name": rec["V_Name"],
                    "normName": rec["V_NormName"],
                    "comment": rec["V_Comment"],
                    "normComment": rec["V_NormComment"],
                    "rank": rec["V_GlobalCount"],
                    "id": rec["V_ID"],
                    "quantity": quantityLLM,
                    "quantityLem": quantityLem,
                    "pack": (rec["V_PackType"] or "").split(","),
                    "type": typef,
                    "gtin": rec["V_GTIN"],
                    "gtinRef": rec["V_GTINRef"],
                    "brand": rec["V_Trademark"],
                    "time": timeM,
                    "event": eventtime,
                    "serving": "",
                    # Position is important (-1 when normName is absent)
                    "posiNormName": (rec["V_NormName"] or "").find(normName),
                })

            if not sols:
                KCATrace.TRACE_ERROR("No solution for query: " + q)
                continue

            # Sort
            sorted_sols = sorted(sols, key=sortFood)

            KCATrace.TRACE("")
            KCATrace.TRACE("------------- Found solution (max 20) --------------")
            for ss in sorted_sols[:20]:
                # str() keeps NULL columns from crashing the trace
                KCATrace.TRACE(" - ".join(str(ss[k]) for k in ("name", "normName", "comment", "brand", "rank", "gtin", "gtinRef", "id")))
            KCATrace.TRACE("----------------------------------------------------")
            KCATrace.TRACE("")

            fsols.append(_selectBestSolution(sorted_sols, quantityLem, jpic))
    finally:
        # Close DB, even if a lookup raised
        dbCursor.close()
        conn.close()

    return fsols


def containUnit(txt):
    """Tell whether a short quantity string already carries a measure unit.

    Only strings shorter than 6 characters qualify, and only if they contain
    "gr", "ml", "cl" or "l" (case-insensitive).

    NOTE(review): the bare "l" marker subsumes "ml"/"cl" and also matches
    short words such as "1 bol"; confirm this is intended.
    """
    lowered = txt.lower()
    if len(lowered) >= 6:
        return False
    return any(marker in lowered for marker in ('gr', 'ml', 'cl', 'l'))


def getBestPicto(quantity, packTab, jpic):
    """Return the serving picto matching *quantity* for one of *packTab*.

    The quantity is first looked up as is. When that fails (e.g.
    "10 verres"), the leading count is replaced by "2" and the lookup is
    retried. Every "200" in the resulting picto is then rewritten as
    count * 100 (so "10" gives "1000"). The retry result is kept only when
    the leading word is a plain decimal integer.

    Returns:
        The picto string, or "" when nothing matches.
    """
    pictoResult = identifyBestPicto(quantity, packTab, jpic)
    if pictoResult != "":
        return pictoResult

    # Particular case (ex: 10 verres): retry with a count of 2, then rescale
    words = quantity.split(" ")
    quantity2 = " ".join(["2"] + words[1:])
    pictoResult = identifyBestPicto(quantity2, packTab, jpic)

    # isdecimal() guarantees int() accepts the count; isnumeric() is also
    # true for characters such as "½", which make int() raise
    count = words[0]
    if not count.isdecimal():
        return ""
    return pictoResult.replace("200", str(int(count) * 100))


def identifyBestPicto(quantity, packTab, jpic):
    """Find the picto describing *quantity* for one of the pack codes.

    Args:
        quantity: quantity text (e.g. "1 verre", "100gr").
        packTab: list of pack codes; their first three characters are
            compared with the first three characters of candidate pictos.
        jpic: picto configuration whose "decisionTree" is a list of slots
            with "id", "name", "next" (ids of follow-up slots) and "leaf"
            (candidate pictos). Slot ids are used as list indexes.

    Returns:
        - when the quantity contains a unit: "<first pack>-<quantity
          without spaces>";
        - otherwise the leaf picto of the decision-tree walk whose first
          three characters match a pack code, trying pack codes in order;
        - "" when packTab is empty or nothing matches.
    """
    pictoResult = ""
    if len(packTab) == 0:
        return pictoResult

    # Format quantity
    nq = quantity.lower().strip()

    # Unit included in quantity: build the picto directly
    if containUnit(nq):
        qq = quantity.replace(" ", "")
        if qq != '':
            pictoResult = packTab[0] + "-" + qq
        return pictoResult

    # No unit in quantity: walk the decision tree word by word
    allWords = nq.split(" ")
    nbw = len(allWords)
    tab = jpic["decisionTree"]

    # Search first word
    nextIDs = []
    currentID = -1
    for slot in tab:
        if slot["name"] == allWords[0]:
            nextIDs = slot["next"]
            currentID = slot["id"]
            break

    pictos = []
    if currentID == -1:
        # Unknown first word: without this guard tab[-1] (the last slot)
        # would silently be used as the match
        KCATrace.TRACE("ERROR: " + "Wrong quantity: '" + quantity + "'")
    else:
        # Follow the remaining words; words with no matching child are skipped
        found = False
        for cw in allWords[1:]:
            for i in nextIDs:
                if tab[i]["name"] == cw:
                    nextIDs = tab[i]["next"]
                    currentID = i
                    found = True
                    break
        leaf = tab[currentID]["leaf"]
        if len(leaf) == 0:
            KCATrace.TRACE("ERROR: " + "Wrong quantity: '" + quantity + "'")
        elif found or nbw == 1:
            pictos = leaf.copy()

    # Merge with existing: first pack code sharing its trigram with a picto
    for p in packTab:
        for pic in pictos:
            if p[0:3] == pic[0:3]:
                pictoResult = pic
                break
        if pictoResult != "":
            break

    return pictoResult

