/* * Copyright (C) 2020 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "run_tflite.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" #define LOG_TAG "NN_MPROC_STRESS" constexpr int kInvalidArguments = -1; enum Arguments : int { kArgModelPath = 1, kArgInputDataPath, kArgInputShape, kArgInputElementSize, kArgProcessCount, kArgThreadCount, kArgDurationSeconds, kArgTestName, kArgJustCompileModel, kArgProcessFailureRatePercent, kArgNnApiDeviceName, kArgMmapModel }; constexpr int kMandatoryArgsCount = 9; const char* kUsage = R"""(%s modelFileName inputDataFile inputShape inputElementByteSize procCount threadCount durationInSeconds testName justCompileModel [processFailureRate] [nnapiDeviceName] [mmapModel] where: inputShape comma separated list of integers (e.g. '1,224,224,3') justCompileModel: true/false) processFailureRate: 0 to 100 percent probability of having one of the client processes failing. Defaults to 0.) mmapModel: true/false select if the TFLite model should be memory mapped to the given file or created from program memory)"""; bool canReadInputFile(const char* path) { std::string modelFileName(path); std::ifstream fstream(modelFileName); std::stringstream readBuffer; readBuffer << fstream.rdbuf(); return fstream.good(); } bool readInputData(const char* inputDataFileName, std::vector input_shape, int inputElementSize, std::vector* result) { int inputElementCount = 1; std::for_each( input_shape.begin(), input_shape.end(), [&inputElementCount](int dimSize) { inputElementCount *= dimSize; }); size_t inputDataSizeBytes = inputElementCount * inputElementSize; std::ifstream dataFile; dataFile.open(inputDataFileName); if (!dataFile) { return false; } std::function failToGenerateData = [](uint8_t*, size_t) { return false; }; while (!dataFile.eof()) { std::unique_ptr dataBuffer = std::make_unique(inputDataSizeBytes); if (!dataFile.read(reinterpret_cast(dataBuffer.get()), inputDataSizeBytes)) { break; } InferenceInOut entry{ dataBuffer.release(), inputDataSizeBytes, {}, failToGenerateData}; result->push_back({entry}); } return result; } bool runModel(const char* modelFileName, const std::vector& data, int durationSeconds, const std::string& nnApiDeviceName, bool justCompileModel, bool mmapModel) { if (justCompileModel) { std::time_t startTime = std::time(nullptr); while (std::difftime(std::time(nullptr), startTime) < durationSeconds) { int nnapiErrno = 0; std::unique_ptr model(BenchmarkModel::create( modelFileName, /*useNnApi=*/true, /*enableIntermediateTensorsDump=*/false, &nnapiErrno, nnApiDeviceName.empty() ? nullptr : nnApiDeviceName.c_str(), mmapModel, /*nnapi_cache_dir=*/nullptr)); if (!model) { __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error creating model"); return false; } // sleeping from 300ms to 800ms constexpr int kMinPauseMs = 300; constexpr int kMaxPauseMs = 800; int sleepForMs = kMinPauseMs + (drand48() * (kMaxPauseMs - kMinPauseMs)); usleep(sleepForMs * 1000); } return true; } else { int nnapiErrno = 0; std::unique_ptr model(BenchmarkModel::create( modelFileName, /*useNnApi=*/true, /*enableIntermediateTensorsDump=*/false, &nnapiErrno, nnApiDeviceName.empty() ? nullptr : nnApiDeviceName.c_str(), mmapModel, /*nnapi_cache_dir=*/nullptr)); if (!model) { __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error creating model"); return false; } std::vector result; constexpr int flags = FLAG_DISCARD_INFERENCE_OUTPUT | FLAG_IGNORE_GOLDEN_OUTPUT; return model->benchmark(data, std::numeric_limits::max(), durationSeconds, flags, &result); } } bool getBooleanArg(int argc, char* argv[], int argIndex, bool defaultValue) { if (argc > argIndex) { std::string argAsString(argv[argIndex]); return argAsString == "true"; } else { return defaultValue; } } int getIntArg(int argc, char* argv[], int argIndex, int defaultValue) { if (argc > argIndex) { return std::atoi(argv[argIndex]); } else { return defaultValue; } } int main(int argc, char* argv[]) { if (argc < kMandatoryArgsCount) { __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, kUsage, kMandatoryArgsCount, argc, argv[0]); return kInvalidArguments; } const char* modelFileName = argv[kArgModelPath]; const char* inputDataFileName = argv[kArgInputDataPath]; const char* testName = argv[kArgTestName]; std::string nnApiDeviceName{ argc > kArgNnApiDeviceName ? argv[kArgNnApiDeviceName] : ""}; int numProcesses = getIntArg(argc, argv, kArgProcessCount, 0); int numThreads = getIntArg(argc, argv, kArgThreadCount, 0); int durationSeconds = getIntArg(argc, argv, kArgDurationSeconds, 0); bool justCompileModel = getBooleanArg(argc, argv, kArgJustCompileModel, false); std::vector inputShape; std::istringstream inputShapeStream(argv[kArgInputShape]); std::string currSizeToken; while (std::getline(inputShapeStream, currSizeToken, ',')) { inputShape.push_back(std::stoi(currSizeToken)); } int inputElementSize = getIntArg(argc, argv, kArgInputElementSize, 0); int processFailureRate = getIntArg(argc, argv, kArgProcessFailureRatePercent, 0); bool mmapModel = getBooleanArg(argc, argv, kArgMmapModel, true); // Validate params if (!canReadInputFile(modelFileName)) { __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error reading model file '%s'", modelFileName); return kInvalidArguments; } std::vector inputData; if (!justCompileModel) { if (!readInputData(inputDataFileName, inputShape, inputElementSize, &inputData)) { __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Error reading input data file '%s'", inputDataFileName); return kInvalidArguments; } } if (numProcesses <= 0 || numThreads <= 0 || durationSeconds <= 0 || inputElementSize <= 0) { __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, "Invalid arguments"); return kInvalidArguments; } __android_log_print( ANDROID_LOG_INFO, LOG_TAG, "Test '%s': running %s of model at path '%s' with input shape [%s] " "(element data size %d)," " %d processes of %d threads each using device '%s' for %d seconds", testName, justCompileModel ? "compilation only" : "full inference", modelFileName, argv[kArgInputShape], inputElementSize, numProcesses, numThreads, nnApiDeviceName.empty() ? "no-device" : nnApiDeviceName.c_str(), durationSeconds); srand48(time(NULL) + getpid()); std::vector children; pid_t pid = 1; bool forkSucceeded = true; bool isSubprocess = false; for (int i = 0; i < numProcesses; i++) { if (pid != 0) { pid = fork(); if (pid > 0) { children.push_back(pid); __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Forked child pid %d", pid); } else if (pid < 0) { forkSucceeded = false; break; } else { isSubprocess = true; } } } if (isSubprocess) { __android_log_print( ANDROID_LOG_INFO, LOG_TAG, "%s model '%s': for %d seconds on device '%s' on %d threads", justCompileModel ? "Compiling" : "Running", modelFileName, durationSeconds, nnApiDeviceName.empty() ? "no-device" : nnApiDeviceName.c_str(), numThreads); bool shouldKillProcess = (drand48() * 100) <= (double)processFailureRate; if (shouldKillProcess) { float killAfter = durationSeconds * drand48(); __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "This process will be killed in %f seconds", killAfter); std::thread killer = std::thread([killAfter]() { usleep(killAfter * 1000.0 * 1000); __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Killing current test process."); kill(getpid(), 9); }); killer.detach(); } std::vector threads; threads.reserve(numThreads); for (int i = 0; i < numThreads; i++) { threads.push_back(std::thread([&]() { runModel(modelFileName, inputData, durationSeconds, nnApiDeviceName, justCompileModel, mmapModel); })); } std::for_each(threads.begin(), threads.end(), [](std::thread& t) { t.join(); }); } else { for (auto pid : children) { waitpid(pid, nullptr, 0); } } __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "Test '%s': %s returning ", testName, isSubprocess ? "Test process" : "Main process"); return 0; }