import {AiProjectService} from "services/aiStudio/ai-project.service";
import {Injectable} from "@angular/core";
import {HttpClient} from "@angular/common/http";
import {Config} from "global/config";
import {Model} from "@teamviewer/aistudioapi-common-angular";
import {GROUND_TRUTH, IMAGE_SIZE_SMALL, PREDICTION} from 'utils/project-utils'

/**
 * Service wrapping the `/aiTraining` HTTP endpoints (start, rename,
 * publish/unpublish, delete) and shaping a model's misclassification
 * results into a per-class structure.
 */
@Injectable()
export class TrainingService {
    constructor(private aiProjectService: AiProjectService, private http: HttpClient, private config: Config) {}

    // Plain state holders; nothing inside this service reads or writes them
    // (presumably shared between components — confirm against callers).
    showTrainingDetails = true;
    userPreference: any = null;

    /**
     * POSTs the model to `/aiTraining/{projectUuid}` to start a new training.
     *
     * @param model Model to train; the project UUID is read from `model.aiProject`.
     * @returns Promise settling with the HTTP response body.
     */
    startNewTraining(model: Model) {
        // NOTE: when `model.aiProject` is unset the URL ends in "/undefined";
        // the request is still sent and the outcome is left to the backend.
        const projectUuid = model.aiProject?.projectUuid;
        const endpoint = `${this.config.baseUrl}/aiTraining/${projectUuid}`;
        return this.http.post(endpoint, model).toPromise();
    }

    /**
     * Groups the misclassified images of a model by their predicted label.
     *
     * Entries without a prediction label (including entries that carry no
     * `misclassifiedDataList` at all) are skipped. Keys are inserted in
     * default-sorted order; images keep their order from
     * `model.trainingResult.misClassifieds`.
     *
     * @param model Object carrying `trainingResult.misClassifieds`; may be null/undefined.
     * @returns Object mapping each predicted label to an array of
     *          `{name, url, groundTruth}`; empty when there is nothing to show.
     */
    prepareMisclassifiedImages(model: any) {
        const misclassifiedImages: any = {};

        const misClassifieds = model?.trainingResult?.misClassifieds;
        if (!misClassifieds) {
            return misclassifiedImages;
        }

        const projectUuid = this.aiProjectService.getAiProject().projectUuid;

        // Label of the first data-list item of the given type, if any.
        const labelOf = (entry: any, dataType: any) =>
            entry?.misclassifiedDataList?.find((item: any) => item.misclassifiedDataType === dataType)?.label;

        // Single pass: bucket every image under its predicted label.
        const imagesByPrediction = new Map<any, any[]>();
        for (const entry of misClassifieds) {
            const prediction = labelOf(entry, PREDICTION);
            if (!prediction) {
                continue;
            }

            let images = imagesByPrediction.get(prediction);
            if (!images) {
                images = [];
                imagesByPrediction.set(prediction, images);
            }
            images.push({
                name: entry.path,
                url: projectUuid + '/datum/' + entry.datumId + '/image?size=' + IMAGE_SIZE_SMALL,
                groundTruth: labelOf(entry, GROUND_TRUTH)
            });
        }

        // Insert keys in sorted order so consumers iterate classes predictably.
        for (const prediction of Array.from(imagesByPrediction.keys()).sort()) {
            misclassifiedImages[prediction] = imagesByPrediction.get(prediction);
        }

        return misclassifiedImages;
    }

    /**
     * Sets `model.name` and PUTs the model to the backend.
     * The passed-in model is mutated before the request is sent.
     */
    renameTraining(name: string, model: Model): Promise<any> {
        model.name = name;
        return this.putModel(model);
    }

    /**
     * @returns The models of the current AI project whose `published` flag is truthy
     *          (an empty array when the project has no models).
     */
    getPublishedModel() {
        const models = this.aiProjectService.getAiProject().models || [];
        return models.filter((model: Model) => model.published);
    }

    /**
     * Sets `model.published = true` and PUTs the model to the backend.
     * The passed-in model is mutated before the request is sent.
     */
    publishModel(model: Model): Promise<any> {
        model.published = true;
        return this.putModel(model);
    }

    /**
     * Sets `model.published = false` and PUTs the model to the backend.
     * The passed-in model is mutated before the request is sent.
     */
    unpublishModel(model: Model): Promise<any> {
        model.published = false;
        return this.putModel(model);
    }

    /**
     * Sends DELETE to `/aiTraining/model/{modelUuid}`.
     *
     * @returns Promise settling with the HTTP response body.
     */
    deleteTraining(model: Model) {
        const trainingId = model.modelUuid;
        const endpoint = `${this.config.baseUrl}/aiTraining/model/${trainingId}`;
        return this.http.delete(endpoint).toPromise();
    }

    /** PUTs the given model to `/aiTraining/model/{modelUuid}`; shared by rename/publish/unpublish. */
    private putModel(model: Model): Promise<any> {
        return this.http.put(this.config.baseUrl + `/aiTraining/model/${model.modelUuid}`, model).toPromise();
    }
}
