import KCALLMEngine
import json
import os
import sqlite3
import sqlite_vec
import numpy as np
import KCALLMCore
import time

from typing import List
import struct


def serialize_f32(vector: List[float]) -> bytes:
    """Pack *vector* into raw float32 bytes with no header or length prefix.

    The struct format has no byte-order prefix, so values are written in the
    platform's native byte order. NOTE(review): the sqlite-vec BLOB layout is
    assumed to match that order; confirm if targeting a big-endian host.
    """
    count = len(vector)
    return struct.pack(f"{count}f", *vector)

def updateBestSolutionsDB(spath, vpath, fromScratch = 0):
    """Fill in missing word embeddings in *spath*, then rebuild the vector DB at *vpath*.

    *spath* is a JSON object mapping each word to a dict whose "s" entry holds
    that word's embedding (empty/missing when not computed yet). Words without
    an embedding, or every word when *fromScratch* is non-zero, are embedded
    in batches of 10 via KCALLMEngine.generateEmbedding with
    KCALLMCore.MODEL_MISTRAL_ONLINE. The JSON file is rewritten only when at
    least one batch was embedded. *vpath* is always deleted and rebuilt from
    all signatures.

    Args:
        spath: path of the JSON signature store (read, and possibly rewritten).
        vpath: path of the sqlite database to (re)create.
        fromScratch: non-zero to re-embed every word, 0 to embed only the
            words whose "s" entry is empty or missing.

    Raises:
        RuntimeError: if the engine returns a different number of embeddings
            than words requested for a batch.
    """
    with open(spath, 'r', encoding='utf-8') as file:
        sigs = json.load(file)

    nbMax = 10  # words per embedding request
    nbRecords = 0
    foods = []
    for w, entry in sigs.items():
        # Keep existing embeddings unless a full rebuild was requested.
        if fromScratch == 0 and entry.get("s"):
            continue
        foods.append(w)
        if len(foods) == nbMax:
            nbRecords += 1
            print("Record #" + str(nbRecords))
            _embedBatch(foods, sigs)
            foods = []

    # Final, possibly partial, batch: now validated just like full batches.
    if foods:
        nbRecords += 1
        print("Record #" + str(nbRecords))
        _embedBatch(foods, sigs)

    if nbRecords != 0:
        # Write to a sibling temp file and swap it in atomically, so a crash
        # mid-write cannot destroy the previously saved embeddings (the old
        # code deleted the file before rewriting it).
        tmpPath = spath + ".tmp"
        with open(tmpPath, 'w', encoding='utf-8') as fdic:
            json.dump(sigs, fdic)
        os.replace(tmpPath, spath)

    _rebuildVectorDB(vpath, sigs)
    return


def _embedBatch(foods, sigs):
    """Embed the words in *foods* with one engine call and store each vector in sigs[word]["s"].

    Raises:
        RuntimeError: if the engine does not return exactly one embedding per word.
    """
    embs = KCALLMEngine.generateEmbedding(foods, KCALLMCore.MODEL_MISTRAL_ONLINE)
    if len(embs) != len(foods):
        # Previously: print + exit(0), which ended the process with a success status.
        raise RuntimeError(
            f"ERROR in dimension: expected {len(foods)} embeddings, got {len(embs)}"
        )
    for word, emb in zip(foods, embs):
        sigs[word]["s"] = emb


def _rebuildVectorDB(vpath, sigs):
    """Delete *vpath* and recreate it as a sqlite-vec DB holding every signature in *sigs*."""
    if os.path.exists(vpath):
        os.remove(vpath)
    conn = sqlite3.connect(vpath)
    try:
        # Load the sqlite_vec extension (provides vec0 and vec_version()).
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)

        # Versioning
        sql_version, vec_version = conn.execute("select sqlite_version(), vec_version()").fetchone()
        print(f"sql_version={sql_version}, vec_version={vec_version}")

        dim = KCALLMCore.getSignatureSize(KCALLMCore.MODEL_MISTRAL_ONLINE)
        conn.execute(f"CREATE VIRTUAL TABLE vec_items USING vec0(embedding float[{dim}])")
        conn.execute("CREATE TABLE word_items (key INTEGER PRIMARY KEY, word TEXT)")

        # The same rowid in both tables links an embedding to its word; rowids
        # follow the JSON key order, starting at 0.
        for idx, (word, entry) in enumerate(sigs.items()):
            conn.execute("INSERT INTO vec_items(rowid, embedding) VALUES (?, ?)", [idx, serialize_f32(entry["s"])])
            conn.execute("INSERT INTO word_items(rowid, word) VALUES (?, ?)", [idx, word])
            if (idx + 1) % 500 == 0:
                print("Record #" + str(idx + 1))
        # One transaction for the whole load instead of one commit per row.
        conn.commit()
    finally:
        conn.close()


def searchTheoricalSolutions(text, words):
    """Print *words* ranked by cosine distance to *text*, nearest first.

    *text* and every entry of *words* are embedded together in a single
    KCALLMEngine.generateEmbedding call (MODEL_MISTRAL_ONLINE). Each word's
    distance to the text embedding comes from computeDistance. Nothing is
    returned. Since ``[text] + words`` always has at least one element, the
    length guard only skips the work when *words* is empty.
    """
    allWords = [text] + words
    if len(allWords) < 2:
        return

    embeddings = KCALLMEngine.generateEmbedding(allWords, KCALLMCore.MODEL_MISTRAL_ONLINE)
    reference = embeddings[0]

    ranked = []
    for position, candidate in enumerate(allWords[1:], start=1):
        ranked.append({"word": candidate, "dist": computeDistance(reference, embeddings[position])})

    # list.sort is stable, so words at equal distance keep their input order.
    ranked.sort(key=lambda item: item["dist"])
    for item in ranked:
        print(f"Word: {item['word']}    -    distance: {item['dist']}")
    return


def computeDistance(t1, t2):
    """Return the cosine distance ``1 - cos(t1, t2)`` between two vectors.

    The result lies in [0, 2]: 0 for vectors pointing the same way, 1 for
    orthogonal ones, 2 for opposite ones. The cosine is clipped to [-1, 1] so
    rounding error cannot push the distance slightly outside that range.

    If either vector has zero norm, the cosine is undefined. Rather than
    dividing by zero (NaN plus a RuntimeWarning, which then breaks the
    distance sort in searchTheoricalSolutions), the pair is treated as
    unrelated and 1.0 is returned. This matches scikit-learn's
    ``cosine_distances`` behaviour.

    Args:
        t1, t2: equal-length sequences (or arrays) of numbers.

    Returns:
        The cosine distance as a Python float.
    """
    a = np.asarray(t1, dtype=float)
    b = np.asarray(t2, dtype=float)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    if norm_product == 0.0:
        return 1.0
    # np.clip propagates NaN, so NaN inputs still surface as NaN, not as a fake distance.
    cosine_similarity = float(np.clip(np.dot(a, b) / norm_product, -1.0, 1.0))
    return 1.0 - cosine_similarity


def searchBestSolutions(text, vpath, k=5):
    """Return the *k* stored words nearest to *text* in the sqlite-vec DB at *vpath*.

    Embeds *text* with KCALLMEngine.generateEmbedding (MODEL_MISTRAL_ONLINE)
    and runs a KNN query on ``vec_items``. It then looks up each hit's word in
    ``word_items`` (the same rowid links the two tables) and prints each
    match.

    Args:
        text: query passed unchanged to generateEmbedding; only the first
            returned embedding (``sig[0]``) is used. NOTE(review): a bare
            string is passed here, whereas searchTheoricalSolutions passes a
            list; confirm the engine accepts both.
        vpath: path of the database built by updateBestSolutionsDB.
        k: number of nearest neighbours to return (default 5, the value that
            was previously hard-coded).

    Returns:
        A list of ``{"word": str, "dist": float}`` dicts, ordered by
        increasing distance.

    Raises:
        FileNotFoundError: if *vpath* does not exist. Checking first avoids
            sqlite silently creating an empty DB file and avoids a wasted
            embedding call.
    """
    if not os.path.isfile(vpath):
        raise FileNotFoundError(f"Embedding DB not found: {vpath}")

    k = int(k)  # int cast makes formatting k into the SQL below injection-safe
    conn = sqlite3.connect(vpath)
    try:
        # Load the sqlite_vec extension (provides the vec0 MATCH operator).
        conn.enable_load_extension(True)
        sqlite_vec.load(conn)
        conn.enable_load_extension(False)

        sig = KCALLMEngine.generateEmbedding(text, KCALLMCore.MODEL_MISTRAL_ONLINE)

        query = f"""SELECT rowid, distance
FROM vec_items
WHERE embedding MATCH ? AND k = {k}
ORDER BY distance;"""

        # Time the fetch too, so the reported figure covers the whole DB work.
        start = time.process_time()
        results = conn.execute(query, [serialize_f32(sig[0])]).fetchall()
        cpu_time = time.process_time() - start
        print(f"--> CPU time in DB: {cpu_time:.4f} seconds")

        wordTab = []
        for rowid, rowdist in results:
            # Parameterized lookup instead of string-concatenated SQL.
            wordRows = conn.execute(
                "SELECT rowid, word FROM word_items WHERE rowid = ?", (rowid,)
            ).fetchall()
            for wordRowid, word in wordRows:
                print(f"Word: {word}  -  dist: {rowdist}  -  row: {wordRowid}")
                wordTab.append({"word": word, "dist": rowdist})
    finally:
        conn.close()

    return wordTab

#-------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------
# Test: optionally rebuild the embedding DB, then run a sample nearest-word search
#-------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------
def runUpdateAndRunBasicTest(text, fullUpdate):
    """Optionally refresh the embedding store, then compare both searches for *text*.

    Files live under ``<cwd>/data/embeddings``: ``foodWording.json`` (the
    signature store) and ``embeddingDB.sl3`` (the sqlite-vec DB).

    Args:
        text: reference text to search for.
        fullUpdate: 0 skips any update and only prints a banner; 2 re-embeds
            every word (fromScratch=1); any other value embeds only the words
            missing a signature (fromScratch=0). Every non-zero value rebuilds
            the DB.

    It first prints the DB (sqlite-vec) results. It then re-ranks those same
    words with searchTheoricalSolutions, which uses freshly computed
    embeddings.
    """
    embeddingDir = os.path.join(os.getcwd(), "data", "embeddings")
    spath = os.path.join(embeddingDir, "foodWording.json")
    vpath = os.path.join(embeddingDir, "embeddingDB.sl3")
    # Alternative DB location previously used:
    # os.path.join(os.getcwd(), "__deriveddata__", "DerivedObjects", "Data", "embeddingDB.sl3")

    if fullUpdate != 0:
        updateBestSolutionsDB(spath, vpath, 1 if fullUpdate == 2 else 0)
    else:
        banner = "++++++++++++++++++++++++++++++++++++++++++++"
        print(banner)
        print("     No forced embedding recreation")
        print(banner)

    rule = "-------------------------------------"
    print(rule)
    print("Reference text: " + text)
    print(rule)
    print()
    print("Results with LOW signature dimension:")
    resTab = searchBestSolutions(text, vpath)

    print()
    print("Results with HIGH signature dimension:")
    searchTheoricalSolutions(text, [entry["word"] for entry in resTab])
    return

