/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.federatedcompute.services.data;

import static com.android.adservices.service.stats.AdServicesStatsLog.AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE;
import static com.android.adservices.service.stats.AdServicesStatsLog.AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE;
import static com.android.federatedcompute.services.data.FederatedTraningTaskContract.FEDERATED_TRAINING_TASKS_TABLE;

import android.annotation.NonNull;
import android.annotation.Nullable;
import android.content.ContentValues;
import android.content.Context;
import android.database.Cursor;
import android.database.SQLException;
import android.database.sqlite.SQLiteDatabase;
import android.database.sqlite.SQLiteException;

import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.data.FederatedTraningTaskContract.FederatedTrainingTaskColumns;
import com.android.federatedcompute.services.statsd.ClientErrorLogger;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;

import java.util.ArrayList;
import java.util.List;

/** DAO for accessing training task table. */
public class FederatedTrainingTaskDao {

    private static final String TAG = FederatedTrainingTaskDao.class.getSimpleName();

    private final FederatedComputeDbHelper mDbHelper;
    private static volatile FederatedTrainingTaskDao sSingletonInstance;

    private FederatedTrainingTaskDao(FederatedComputeDbHelper dbHelper) {
        this.mDbHelper = dbHelper;
    }

    /** Returns an instance of the FederatedTrainingTaskDao given a context. */
    @NonNull
    public static FederatedTrainingTaskDao getInstance(Context context) {
        if (sSingletonInstance == null) {
            synchronized (FederatedTrainingTaskDao.class) {
                if (sSingletonInstance == null) {
                    sSingletonInstance =
                            new FederatedTrainingTaskDao(
                                    FederatedComputeDbHelper.getInstance(context));
                }
            }
        }
        return sSingletonInstance;
    }

    /** It's only public to unit test. */
    @VisibleForTesting
    public static FederatedTrainingTaskDao getInstanceForTest(Context context) {
        synchronized (FederatedTrainingTaskDao.class) {
            if (sSingletonInstance == null) {
                FederatedComputeDbHelper dbHelper =
                        FederatedComputeDbHelper.getInstanceForTest(context);
                sSingletonInstance = new FederatedTrainingTaskDao(dbHelper);
            }
            return sSingletonInstance;
        }
    }

    /** Deletes a training task in FederatedTrainingTask table. */
    private void deleteFederatedTrainingTask(String selection, String[] selectionArgs) {
        SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
        if (db == null) {
            ClientErrorLogger.getInstance()
                    .logError(
                            AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                            AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
            return;
        }
        db.delete(FEDERATED_TRAINING_TASKS_TABLE, selection, selectionArgs);
    }

    /** Insert a training task or update it if task already exists. */
    public boolean updateOrInsertFederatedTrainingTask(FederatedTrainingTask trainingTask) {
        try {
            SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
            if (db == null) {
                return false;
            }
            return trainingTask.addToDatabase(db);
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to persist federated training task %s",
                    trainingTask.populationName());
            return false;
        }
    }

    /** Get the list of tasks that match select conditions. */
    @Nullable
    public List<FederatedTrainingTask> getFederatedTrainingTask(
            String selection, String[] selectionArgs) {
        SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
        if (db == null) {
            return null;
        }
        return FederatedTrainingTask.readFederatedTrainingTasksFromDatabase(
                db, selection, selectionArgs);
    }

    /** Delete a task from table based on job scheduler id. */
    public FederatedTrainingTask findAndRemoveTaskByJobId(int jobId) {
        String selection = FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID + " = ?";
        String[] selectionArgs = selectionArgs(jobId);
        FederatedTrainingTask task =
                Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(TAG, e, "Failed to delete federated training task by job id %d", jobId);
            ClientErrorLogger.getInstance()
                    .logErrorWithExceptionInfo(
                            e,
                            AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                            AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
            return null;
        }
    }

    /** Delete a task from table based on population name. */
    public FederatedTrainingTask findAndRemoveTaskByPopulationName(String populationName) {
        String selection = FederatedTrainingTaskColumns.POPULATION_NAME + " = ?";
        String[] selectionArgs = {populationName};
        FederatedTrainingTask task =
                Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by population name %s",
                    populationName);
            ClientErrorLogger.getInstance()
                    .logErrorWithExceptionInfo(
                            e,
                            AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                            AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
            return null;
        }
    }

    /** Delete a task from table based on population name and calling package. */
    public FederatedTrainingTask findAndRemoveTaskByPopulationNameAndCallingPackage(
            String populationName, String callingPackage) {
        String selection =
                FederatedTrainingTaskColumns.POPULATION_NAME
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.APP_PACKAGE_NAME
                        + " = ?";
        String[] selectionArgs = {populationName, callingPackage};
        FederatedTrainingTask task =
                Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by "
                            + "population name %s and calling package: %s",
                    populationName,
                    callingPackage);
            ClientErrorLogger.getInstance()
                    .logErrorWithExceptionInfo(
                            e,
                            AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                            AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
            return null;
        }
    }

    /** Delete a task from table based on population name and owner Id (package and class name). */
    public FederatedTrainingTask findAndRemoveTaskByPopulationNameAndOwnerId(
            String populationName, String ownerPackage, String ownerClass, String ownerCertDigest) {
        String selection =
                FederatedTrainingTaskColumns.POPULATION_NAME
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.OWNER_PACKAGE
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.OWNER_CLASS
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.OWNER_ID_CERT_DIGEST
                        + " = ?";
        String[] selectionArgs = {populationName, ownerPackage, ownerClass, ownerCertDigest};
        FederatedTrainingTask task =
                Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by population name %s and ATP: %s/%s",
                    populationName,
                    ownerPackage,
                    ownerClass);
            ClientErrorLogger.getInstance()
                    .logErrorWithExceptionInfo(
                            e,
                            AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                            AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
            return null;
        }
    }

    /** Delete a task from table based on population name and job scheduler id. */
    public FederatedTrainingTask findAndRemoveTaskByPopulationAndJobId(
            String populationName, int jobId) {
        String selection =
                FederatedTrainingTaskColumns.POPULATION_NAME
                        + " = ? AND "
                        + FederatedTrainingTaskColumns.JOB_SCHEDULER_JOB_ID
                        + " = ?";
        String[] selectionArgs = {populationName, String.valueOf(jobId)};
        FederatedTrainingTask task =
                Iterables.getOnlyElement(getFederatedTrainingTask(selection, selectionArgs), null);
        try {
            if (task != null) {
                deleteFederatedTrainingTask(selection, selectionArgs);
            }
            return task;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    e,
                    "Failed to delete federated training task by population name %s and job id %d",
                    populationName,
                    jobId);
            ClientErrorLogger.getInstance()
                    .logErrorWithExceptionInfo(
                            e,
                            AD_SERVICES_ERROR_REPORTED__ERROR_CODE__DELETE_TASK_FAILURE,
                            AD_SERVICES_ERROR_REPORTED__PPAPI_NAME__FEDERATED_COMPUTE);
            return null;
        }
    }

    /** Returns number of tasks already belongs to given owners package. */
    public int getTotalTrainingTaskPerOwnerPackage(String packageName) {
        SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
        if (db == null) {
            return 0;
        }
        final String query =
                "SELECT COUNT(*) FROM "
                        + FEDERATED_TRAINING_TASKS_TABLE
                        + " WHERE "
                        + FederatedTrainingTaskColumns.OWNER_PACKAGE
                        + " = ?";
        try (Cursor cursor = db.rawQuery(query, new String[] {packageName})) {
            if (cursor.moveToFirst()) {
                return cursor.getInt(0);
            } else {
                return 0; // No matching tasks found
            }
        }
    }

    /** Insert a training task history record or update it if task already exists. */
    public boolean updateOrInsertTaskHistory(TaskHistory taskHistory) {
        try {
            SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
            ContentValues values = new ContentValues();
            values.put(TaskHistoryContract.TaskHistoryEntry.JOB_ID, taskHistory.getJobId());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.POPULATION_NAME,
                    taskHistory.getPopulationName());
            values.put(TaskHistoryContract.TaskHistoryEntry.TASK_ID, taskHistory.getTaskId());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND,
                    taskHistory.getContributionRound());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME,
                    taskHistory.getContributionTime());
            values.put(
                    TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION,
                    taskHistory.getTotalParticipation());
            return db.insertWithOnConflict(
                            TaskHistoryContract.TaskHistoryEntry.TABLE_NAME,
                            null,
                            values,
                            SQLiteDatabase.CONFLICT_REPLACE)
                    != -1;
        } catch (SQLException e) {
            LogUtil.e(
                    TAG,
                    "Failed to update or insert task history %s %s",
                    taskHistory.getPopulationName(),
                    taskHistory.getTaskId());
        }
        return false;
    }

    /** Gets the list of task history based on provided job id, population name and task id. */
    public List<TaskHistory> getTaskHistoryList(int jobId, String populationName, String taskId) {
        return getTaskHistory(jobId, populationName, taskId, false);
    }

    /** Get the latest record of task history based on job id, population name and task name. */
    public TaskHistory getLatestTaskHistory(int jobId, String populationName, String taskId) {
        List<TaskHistory> taskList = getTaskHistory(jobId, populationName, taskId, true);
        return taskList.isEmpty() ? null : taskList.get(0);
    }

    private List<TaskHistory> getTaskHistory(
            int jobId, String populationName, String taskId, boolean latest) {
        SQLiteDatabase db = mDbHelper.safeGetReadableDatabase();
        String selection =
                TaskHistoryContract.TaskHistoryEntry.JOB_ID
                        + " = ? AND "
                        + TaskHistoryContract.TaskHistoryEntry.POPULATION_NAME
                        + " = ? AND "
                        + TaskHistoryContract.TaskHistoryEntry.TASK_ID
                        + " = ?";
        String[] selectionArgs = {String.valueOf(jobId), populationName, taskId};
        String orderBy = TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME + " DESC";
        String[] projection = {
            TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME,
            TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND,
            TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION
        };
        List<TaskHistory> taskList = new ArrayList<>();
        try (Cursor cursor =
                db.query(
                        TaskHistoryContract.TaskHistoryEntry.TABLE_NAME,
                        projection,
                        selection,
                        selectionArgs,
                        /* groupBy= */ null,
                        /* having= */ null,
                        /* orderBy= */ orderBy)) {
            while (cursor.moveToNext()) {
                long contributionTime =
                        cursor.getLong(
                                cursor.getColumnIndexOrThrow(
                                        TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME));
                long contributionRound =
                        cursor.getLong(
                                cursor.getColumnIndexOrThrow(
                                        TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_ROUND));
                long totalParticipation =
                        cursor.getLong(
                                cursor.getColumnIndexOrThrow(
                                        TaskHistoryContract.TaskHistoryEntry.TOTAL_PARTICIPATION));
                taskList.add(
                        new TaskHistory.Builder()
                                .setJobId(jobId)
                                .setTaskId(taskId)
                                .setPopulationName(populationName)
                                .setContributionRound(contributionRound)
                                .setContributionTime(contributionTime)
                                .setTotalParticipation(totalParticipation)
                                .build());
                if (latest) {
                    cursor.close();
                    return taskList;
                }
            }
            cursor.close();
            return taskList;
        } catch (SQLiteException e) {
            LogUtil.e(TAG, e, "Failed to read TaskHistory db");
        }
        return null;
    }

    /** Batch delete expired task history records. */
    public int deleteExpiredTaskHistory(long deleteTime) {
        SQLiteDatabase db = mDbHelper.safeGetWritableDatabase();
        if (db == null) {
            throw new SQLiteException(TAG + ": Failed to open database.");
        }
        String whereClause = TaskHistoryContract.TaskHistoryEntry.CONTRIBUTION_TIME + " < ?";
        String[] whereArgs = {String.valueOf(deleteTime)};
        int deletedRows =
                db.delete(TaskHistoryContract.TaskHistoryEntry.TABLE_NAME, whereClause, whereArgs);
        LogUtil.d(TAG, "Deleted %d expired tokens", deletedRows);
        return deletedRows;
    }

    private String[] selectionArgs(Number... args) {
        String[] values = new String[args.length];
        for (int i = 0; i < args.length; i++) {
            values[i] = String.valueOf(args[i]);
        }
        return values;
    }
}