import os

import openai
from openai import OpenAI

import copy
import json
import argparse
import requests
import time
import uuid
import base64
import anthropic
# import numpy as np

import KCALLMPrompt
import KCALLMConst
import KCALLMCore
import KCALLMTrace as KCATrace

#from mistralai.client import MistralClient
from mistralai import Mistral
#from mistralai import Mistral
#from mistralai.models.chat_completion import ChatMessage



def makeTextJSONCompliant(originalResp):
    """Extract and parse the JSON payload embedded in a raw LLM reply.

    Handles the response shapes seen in practice: a fenced ```json ... ```
    block, a bare JSON object, or a bare JSON array.  Also strips RDF-style
    prefixes, newlines, Python "None" literals and '//' line comments
    before parsing, and unwraps a single hallucinated wrapper attribute.

    Parameters:
        originalResp (str): raw text returned by the LLM.

    Returns:
        The parsed JSON value (dict or list).  On parse failure the error
        is printed and an empty dict is returned.
    """
    
    # Work on a copy; originalResp is kept intact for error reporting.
    resp = originalResp
    
    print()
    print("----------------- Make it compliant ------------------")
    print(resp)
    print("------------------------------------------------------")

    # Remove RDF prefixes the model sometimes copies from the ontology.
    resp = resp.replace("food:", "")
    resp = resp.replace("activity:", "")

    # Case 1: JSON is embedded in a fenced block ```json ... ```
    idx1 = resp.find("```json")
    if idx1 >= 0:
        resp = resp[idx1+7:]   # skip past the 7-char "```json" marker
        idx2 = resp.find("```")
        if idx2 > 0:
            resp = resp[:idx2]   # keep only the fenced content
            # Trim any chatter before the first '{' or '[' inside the fence.
            idx31 = resp.find("{")
            idx32 = resp.find("[")
            idx3 = idx31
            # NOTE(review): when "[" is absent idx32 == -1 < idx31, so idx3
            # becomes -1 and the `> 0` guard below skips trimming (harmless).
            # But when "{" is absent and "[" is present, idx3 stays -1 and an
            # array's leading chatter is NOT trimmed — TODO confirm intended.
            if idx32 < idx31:
                idx3 = idx32
            if idx3 > 0:
                resp = resp[idx3:]

    # Other cases: no fence, scan the bare text.
    else:

        # Array JSON or Object JSON?  An array is assumed when a '['
        # appears before a '{' which itself is before a closing ']'.
        idx1 = resp.find("[")
        idx2 = resp.find("{")
        idx3 = resp.find("]", idx1 + 1)
        isArray = False
        if idx1 >= 0 and idx1 < idx2 and idx2 < idx3:
            isArray = True

        # JSON object: keep the first balanced {...} span (brace counting).
        if isArray == False:
            nbc = len(resp)
            nbp = 0              # current brace nesting depth
            i = 0
            dstart = False       # True once the first '{' has been seen
            i1 = -1              # index of the first '{'
            i2 = -1              # index of its matching '}'
            while(i <nbc):
                s = resp[i]
                if s == '{':
                    nbp += 1
                    if i1 == -1:
                        i1 = i
                    dstart = True
                elif s == '}':
                    nbp -=1
                # Depth back to zero after a '{' was seen: object is closed.
                if nbp == 0 and dstart == True:
                    i2 = i
                    break
                i += 1
            # NOTE(review): if no '{' exists, i1 == i2 == -1 and the slice
            # below yields '' — the json.loads then fails; TODO confirm ok.
            resp = resp[i1:i2+1]

        # JSON array: keep from the first '[' to the last ']'.
        else:
            # Get last config of '}]'
            irdx1 = resp.rfind("]")
            irdx2 = resp.rfind("}")   # NOTE(review): irdx2 is computed but unused
            resp = resp[idx1:irdx1+1]

    # Some other cleansing: flatten to one line, JSON-ify Python None.
    resp = resp.replace("\n", "")
    resp = resp.replace("None", "null")
    #resp = resp.replace("'", '"')
    
    # Remove '//....\n' comment patterns the model sometimes emits.
    # NOTE(review): runs after newlines were stripped above, so idx2 is
    # normally -1 here and nothing is removed — TODO confirm intended order.
    found = True
    while found == True:
        found = False
        idx1 = resp.find("//")
        if idx1 > 0:   # NOTE(review): a '//' at position 0 is never removed
            idx2 = resp.find("\n", idx1)
            if idx2 > 0:
                found = True
                resp = resp[0:idx1] + resp[idx2:len(resp)]


    jfresp = {}
    try:
        jfresp = json.loads(resp)

        #--------------------------------------------------------
        # Particular hallucination (attr prefix the array...)
        # e.g. {"items": [...]} instead of [...]: unwrap the single
        # wrapper attribute, unless it is "intents" (a valid shape) or
        # "name" (a lone record, which is wrapped into a list instead).
        if not jfresp:
            jfresp = []
        elif not isinstance(jfresp, list):
            idx = 0
            attr = ""
            # Iterating a dict yields its keys; count them and keep the last.
            for aa in jfresp:
                attr = aa
                idx += 1
            if idx == 1 and attr != "intents":
                if attr == "name":
                    jfresp = [jfresp]
                else:
                    jfresp = jfresp[attr]
            else:
                print("ERROR: wrong object representation:")
                print(jfresp)
        #--------------------------------------------------------


        print()
        print("------------------------ After simplification ------------------------")
        jjj = json.dumps(jfresp, indent=4)
        print(jjj)
        print("----------------------------------------------------------------------")

    except:
        # Parsing failed: log both the original text and the extracted slice.
        print()
        print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("ERROR: impossible to parse [II]:")
        print(originalResp)
        print()
        print("The extracted string is " + resp)
        print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    
    return jfresp


def setJSONFormat(resp):
    """Coerce a raw LLM answer into a JSON array of objects bearing 'name'.

    The text is bracketed into an array, cleansed of known artifacts
    (stray "Answer:" prefix, null/None literals, newlines, dots), then
    parsed.  Any entry lacking a 'name' attribute invalidates the whole
    answer.

    Parameters:
        resp (str): raw text returned by the LLM.

    Returns:
        list: the parsed array, or [] when parsing/validation fails.
    """
    
    # Default returned on any failure.
    jDefResp = []
    
    # Ensure the text is bracketed as an array: keep from the first '['
    # (or prepend one) up to the last ']' (or append one).
    idx1 = resp.find("[")
    if idx1 >= 0:
        resp = resp[idx1:]
    else:
        resp = "[" + resp
    idx2 = resp.rfind("]")
    if idx2 >= 0:
        resp = resp[0:idx2+1]
    else:
        resp = resp + "]"

    # French quantity fix-up: "'demi" -> "'un demi" unless already present.
    if resp.find("un demi") == -1:
        resp = resp.replace("'demi", "'un demi")
    resp = resp.replace('null', '""')
    resp = resp.replace("\n", "")
    resp = resp.replace("Answer:", "")
    resp = resp.replace(".", "")   # NOTE: also strips decimal points in numbers
    resp = resp.replace("None", "\"\"")

    # Fix: the parse was previously unguarded, so any malformed text
    # crashed the caller; fall back to the empty default instead,
    # consistently with the validity check below.
    try:
        jresp = json.loads(resp)
    except ValueError:   # json.JSONDecodeError is a ValueError subclass
        return jDefResp

    #-----------------------------
    # Check validity
    #-----------------------------
    # Every entry must expose a 'name' attribute; TypeError covers a
    # non-list/non-dict parse result, KeyError a missing attribute.
    try:
        for jj in jresp:
            name = jj['name']
    except (TypeError, KeyError):
        return jDefResp

    return jresp


def splitPrompt(prompt, model):
    """Ask the LLM to split a compound prompt into elementary prompts.

    Parameters:
        prompt (str): the user prompt to split.
        model: model identifier forwarded to runLLM.

    Returns:
        dict: the split prompts from the LLM response, or {} on failure.
    """

    # Build the dedicated splitting prompt and run it through the model.
    splitRequest = KCALLMPrompt.getSplitPrompts(prompt)
    llmOutput = runLLM(splitRequest, "", model)

    # Default to an empty mapping when the response is missing or malformed.
    prompts = {}
    try:
        prompts = llmOutput['response'].copy()
    except:
        pass

    return prompts



def getMistralModelList():
    """Query the Mistral API for the list of available model identifiers.

    Returns:
        list[str]: the 'id' of every model exposed by the /v1/models
        endpoint.  Network or decoding errors propagate to the caller.
    """

    model = KCALLMCore.MODEL_MISTRAL_ONLINE
    key = KCALLMCore.getKey(model)
    url = "https://api.mistral.ai/v1/models"
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": "Bearer " + key,
    }

    resp = requests.get(url, headers=headers)
    # Idiom fix: let requests decode the body instead of json.loads(resp.text).
    jresp = resp.json()

    # One id per model entry.
    return [d["id"] for d in jresp["data"]]


def getIntentsFromPrompt(intentList, prompt, image64, model):
    """Detect which intents from *intentList* apply to the user prompt.

    Parameters:
        intentList: candidate intents passed to the detection prompt.
        prompt (str): the user prompt to analyse.
        image64 (str): base64 image, or '' when the input is text-only.
        model: model identifier forwarded to runLLM.

    Returns:
        list: detected intent names ([] when detection fails).
    """

    # Particular easy case: an attached image always means image-to-food.
    if image64 != '':
        return [KCALLMConst.INTENT_IMAGE_TO_FOOD]

    # Build the detection prompt and run it through the model.
    detectionPrompt = KCALLMPrompt.getIntentDetectionPrompt(prompt, intentList)
    llmOutput = runLLM(detectionPrompt, "", model)

    # Extract the intent list, falling back to [] on any missing key.
    detected = []
    try:
        detected = llmOutput['response']['intents'].copy()
    except:
        pass

    return detected


def runLLM(prompt, image64, mdl, isItJson = True):


    # Init
    rjson = {}
    tokenMaxSize = 3000
    model = mdl

    # Particular case
    if image64 != '':
        if model == KCALLMCore.MODEL_MISTRAL_ONLINE:
            model = KCALLMCore.MODEL_PIXTRAL_ONLINE
        print("##############################################################################################")
        print("#   For image extraction, " + model + " is used    #")
        print("##############################################################################################")
            
    #openai.api_key = os.getenv("OPENAI_API_KEY")
    key = KCALLMCore.getKey(model)
    openai.api_key = key
    
    prompt_tokens = 0
    completion_tokens = 0

    print()
    print("==================================== Prompt =============================================")
    print(prompt)
    print("=========================================================================================")
    print()

    response = ""
    if model == KCALLMCore.MODEL_GPT4:

        key = KCALLMCore.getKey(model)

        #-----------------------
        # TEXT
        #-----------------------
        if image64 == '':
            
            messages=[{"role": "user", "content" : prompt}]

            client = OpenAI(api_key = key)

##            for resp in client.chat.completions.create(
##                model = model,
##                max_tokens = tokenMaxSize,
##                messages = messages,
##                stream=True,
##                temperature=0.
##            ):
##                response += (resp.choices[0].delta.content or "")

            resp = client.chat.completions.create(
                model=model,
                messages=messages,
                stream=False,
                max_tokens = tokenMaxSize,
                temperature=0.
                )
            
            response = resp.choices[0].message.content
            
        #-----------------------
        # IMAGE
        #-----------------------
        else:

            print("Image recognition....")
            headers = {
              "Content-Type": "application/json",
              "Authorization": f"Bearer {key}"
            }

            payload = {
                "model": model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{image64}"
                                    }
                            }
                            ]
                    }
                ],
                "response_format": {
                    "type": "json_object"
                    },
                "max_tokens": 300
                }

            resp = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
            jresp = resp.json()

            response = jresp["choices"][0]["message"]["content"]
   
        
    elif model == KCALLMCore.MODEL_MISTRAL:

        messages=[{"role": "user", "content" : prompt}]
        client = OpenAI(base_url="http://px101.prod.exalead.com:8110/v1", api_key=key)
##        client.completions.create(
##            model=model,
##            prompt="Say this is a test",
##            max_tokens=500,
##            temperature=0.
##            )
        resp = client.chat.completions.create(
            model = model,
            messages = messages,
            response_format={ "type": "json_object" },
            max_tokens=1000,
            temperature=0.
            )
        response = resp.choices[0].message.content
 
            
    elif model == KCALLMCore.MODEL_MISTRAL_ONLINE or model == KCALLMCore.MODEL_PIXTRAL_ONLINE:

        key = KCALLMCore.getKey(model)

        if image64 == '':            

            
            client = Mistral(api_key=key)

            chat_response = client.chat.complete(
                model = model,
                response_format={ "type": "json_object" },
                messages = [
                    {
                        "role": "user",
                        "content": prompt,
                    },
                ]
            ) 

            # client = MistralClient(api_key=key)

            # chat_response = client.chat(
            #     model=model,
            #     response_format={ "type": "json_object" },
            #     messages=[ChatMessage(role="user", content=prompt)]
            # )

            response = chat_response.choices[0].message.content
            completion_tokens = chat_response.usage.completion_tokens
            prompt_tokens = chat_response.usage.prompt_tokens

            """
            messages=[{"role": "user", "content" : prompt}]
            url = "https://api.mistral.ai/v1/chat/completions"
            headers = {"Content-Type": "application/json", \
                    "Accept": "application/json", \
                    "Authorization": "Bearer " + key \
                    }
            data = { "model": model, \
                    "temperature": 0., \
                    "messages": messages \
                    }
            resp = requests.post(url, headers=headers, data=json.dumps(data))
            jresp = json.loads(resp.text)
            
            try:
                response = jresp["choices"][0]["message"]["content"]
                completion_tokens = jresp["usage"]["completion_tokens"]
                prompt_tokens = jresp["usage"]["prompt_tokens"]
            except:
                print("+++++++++++++++++++++++++++++++++++++++++")
                print("+++++++++++++++++++++++++++++++++++++++++")
                print("Impossible to parse:")
                print(jresp)
                print("+++++++++++++++++++++++++++++++++++++++++")
                print("+++++++++++++++++++++++++++++++++++++++++")
                return None
            """
        
        else:
       
            # Getting the base64 string
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}"
            }

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image64}"
                                }
                        }
                    ]
                }
            ]

            url = "https://api.mistral.ai/v1/chat/completions"
            headers = {"Content-Type": "application/json", \
                    "Accept": "application/json", \
                    "Authorization": "Bearer " + key \
                    }

            data = { "model": model, \
                    "temperature": 0., \
                    "messages": messages, \
                    "response_format": {"type": "json_object"} \
                    }

            resp = requests.post(url, headers=headers, data=json.dumps(data))
            jresp = json.loads(resp.text)
            response = jresp["choices"][0]["message"]["content"]



            # # Initialize the Mistral client
            # client = Mistral(api_key=key)

            # # Define the messages for the chat
            # messages = [
            #     {
            #         "role": "user",
            #         "content": [
            #             {
            #                 "type": "text",
            #                 "text": "What's in this image?"
            #             },
            #             {
            #                 "type": "image_url",
            #                 "image_url": {
            #                     "url": f"data:image/jpeg;base64,{image64}"
            #                     }
            #             }
            #         ]
            #     }
            # ]

            # # Get the chat response
            # chat_response = client.chat.complete(
            #     model=model,
            #     messages=messages
            # )

            # # Print the content of the response
            # print(chat_response.choices[0].message.content)



    elif model == KCALLMCore.MODEL_ANTHRO:

        key = KCALLMCore.getKey(model)


        client = anthropic.Anthropic(api_key=key)
        
        messages=[{"role": "user", "content" : [ { "type": "text", "text": prompt}]}]
        
        resp = client.messages.create(
            model=model,
            max_tokens=3000,
            temperature=0,
            system="",
            messages=messages
        )

        response = resp.content[0].text
        
    else:
        return None

    KCATrace.TRACE("")
    KCATrace.TRACE("------------------------------ LLM Raw response -----------------------------")
    KCATrace.TRACE(response)
    KCATrace.TRACE("-----------------------------------------------------------------------------")

    # Display result
    resp = []
    mess = ""

    # Parse result
    try:
        if isItJson == True:
            resp = makeTextJSONCompliant(response)
        else:
            resp = response

        # Null event
        try:
            if resp[0]['event'] == '':
                resp[0]['event'] = 'unknown'
        except:
            pass
    except:
        pass

    # Response
    rjson["response"] = copy.deepcopy(resp)

    # Pricing (GPT-4 32K context)
    input_token_pricing = 0.06 / 1000.0 # euro/1K tokens
    output_token_pricing = 0.12 / 1000.0 # euro/1K tokens
    rjson["cost"] = prompt_tokens * input_token_pricing + completion_tokens * output_token_pricing

    return rjson


def generateEmbedding(text, mdl):
    """Compute embedding vectors for *text* with the model family *mdl*.

    Parameters:
        text: input text(s) forwarded to the embedding endpoint.
        mdl: model identifier (one of the KCALLMCore.MODEL_* constants).

    Returns:
        list: one embedding vector per input item; [] for unsupported
        models.
    """

    signatureDim = KCALLMCore.getSignatureSize(mdl)
    apiKey = KCALLMCore.getKey(mdl)
    embeddingModel = KCALLMCore.getModelForEmbedding(mdl)

    vectors = []

    # OpenAI path: only pass 'dimensions' when a size is configured.
    if mdl == KCALLMCore.MODEL_GPT4:

        client = OpenAI(api_key=apiKey)

        if signatureDim == 0:
            result = client.embeddings.create(
                input=text,
                model=embeddingModel,
            )
        else:
            result = client.embeddings.create(
                input=text,
                dimensions=signatureDim,
                model=embeddingModel,
            )

        vectors = [entry.embedding for entry in result.data]

    # Mistral path: the API exposes a single fixed dimension.
    elif mdl == KCALLMCore.MODEL_MISTRAL_ONLINE:

        client = Mistral(api_key=apiKey)

        result = client.embeddings.create(
            inputs=text,
            model=embeddingModel,
        )

        vectors = [entry.embedding for entry in result.data]

    return vectors

# def normalize_l2(x):
#     x = np.array(x)
#     if x.ndim == 1:
#         norm = np.linalg.norm(x)
#         if norm == 0:
#             return x
#         return x / norm
#     else:
#         norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
#         return np.where(norm == 0, x, x / norm)



