import pickle as pkl
import numpy as np
import faiss
import math
import timeit


class Faiss_cluster:
    """Helpers for merging and matching post embedding vectors with faiss.

    All methods are classmethods; the only shared state is the lazily-created
    GPU resource object used by ``faiss_search_similarity``.
    """

    # faiss GPU resources shared by every search. Creating
    # ``StandardGpuResources`` reserves GPU scratch memory, so it is built once
    # on first use instead of on every call.
    # NOTE: faiss GPU resources are not intended for concurrent use from
    # several threads; if searches run in parallel, give each thread its own.
    _gpu_resources = None

    @classmethod
    def _get_gpu_resources(cls):
        """Return the cached faiss GPU resources, creating them on first use."""
        if cls._gpu_resources is None:
            cls._gpu_resources = faiss.StandardGpuResources()  # single GPU
        return cls._gpu_resources

    @classmethod
    def compute_centroid(cls, cluster_vector, vector_post):
        """Create the centroid of an existing cluster vector and a new post.

        A one-cluster k-means over two points converges to their arithmetic
        mean, so the mean is computed directly. This needs no GPU, no faiss
        training, and works for any vector dimension.

        NOTE(review): the existing cluster vector and the new post get equal
        weight regardless of how many posts the cluster already holds. Confirm
        this is intended.

        Args:
            cluster_vector: array-like current cluster vector of length d.
            vector_post: array-like embedding of the new post, same length d.

        Returns:
            numpy.ndarray: float32 array of length d, the element-wise mean.
        """
        stacked = np.asarray([cluster_vector, vector_post], dtype=np.float32)
        return stacked.mean(axis=0)

    @classmethod
    def faiss_search_similarity(cls, vectors, vector_post, *, threshold=2.5):
        """Find the stored vector nearest to ``vector_post`` on the GPU.

        An exact L2 (brute-force) index is built over ``vectors`` and queried
        for the single nearest neighbour.

        Args:
            vectors: array-like of shape (n, d) with the candidate vectors.
            vector_post: query vector, shape (d,) or (1, d). Only the first
                query row is used for the decision.
            threshold: the match is accepted when ``10 * squared_L2_distance``
                is strictly below this value. The default 2.5 corresponds to a
                squared distance below 0.25.

        Returns:
            The row index (numpy integer) of the nearest vector when it is
            close enough, otherwise ``False``. ``False`` is also returned when
            ``vectors`` is empty.
            NOTE: index 0 compares equal to ``False`` (``0 == False``), so
            callers must test ``result is False`` rather than truthiness.
        """
        vectors = np.asarray(vectors, dtype=np.float32)
        if vectors.size == 0:
            return False  # nothing to match against; faiss would fail here

        # faiss expects C-contiguous float32 matrices of shape (rows, dim).
        database = np.ascontiguousarray(np.atleast_2d(vectors))
        query = np.ascontiguousarray(np.atleast_2d(vector_post), dtype=np.float32)
        dimension = database.shape[1]

        # An exact flat index needs no training and gives the same neighbours
        # as an IVF index with a single inverted list.
        cpu_index = faiss.IndexFlatL2(dimension)
        gpu_index = faiss.index_cpu_to_gpu(cls._get_gpu_resources(), 0, cpu_index)
        gpu_index.add(database)

        distances, labels = gpu_index.search(query, 1)
        nearest = labels[0][0]
        if nearest < 0:
            return False  # faiss reports "no neighbour" as label -1

        # ``distances`` holds squared L2 distances; the scale factor of 10 is
        # kept so existing threshold tuning keeps its meaning.
        score = float(distances[0][0] * 10)
        if score < threshold:
            return nearest
        return False
