Skip to content

Commit

Permalink
Build light weight PyRuntime without llvm or onnx-mlir (onnx#3044)
Browse files Browse the repository at this point in the history
* pass test

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* package

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* clean makefile

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* document

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* fix MLIR.cmake

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* fix script

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* fix

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* add comments

Signed-off-by: Chen Tong <chentong@us.ibm.com>

* LIGHT

Signed-off-by: Chen Tong <chentong@us.ibm.com>

---------

Signed-off-by: Chen Tong <chentong@us.ibm.com>
  • Loading branch information
chentong319 authored and christopherlmunoz committed Jan 30, 2025
1 parent 9d8898b commit e24fd5c
Show file tree
Hide file tree
Showing 23 changed files with 786 additions and 61 deletions.
57 changes: 35 additions & 22 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ option(ONNX_MLIR_ENABLE_STABLEHLO "Enable StableHLO support." ON)
option(ONNX_MLIR_ENABLE_WERROR "Enable warnings as errors." OFF)
option(ONNX_MLIR_SUPPRESS_THIRD_PARTY_WARNINGS "Suppress warning in third_party code." ON)
option(ONNX_MLIR_ENABLE_JAVA "Set to ON for building the Java runtime, tools, and tests" ON)
option(ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT "Set to ON for building Python driver of running the compiled model without llvm-project." OFF)

set(CMAKE_CXX_STANDARD 17)

Expand Down Expand Up @@ -73,8 +74,10 @@ set(ONNX_MLIR_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(ONNX_MLIR_VENDOR ${PACKAGE_VENDOR} CACHE STRING
"Vendor-specific text for showing with version information.")

include(CTest)
include(ExternalProject)
if(NOT ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
include(CTest)
include(ExternalProject)
endif()
include(MLIR.cmake)

# MLIR.cmake calls find_package(MLIR) which sets LLVM_MINIMUM_PYTHON_VERSION
Expand Down Expand Up @@ -159,23 +162,29 @@ endif()
set(CMAKE_MESSAGE_LOG_LEVEL NOTICE)

# Add third party subdirectories and define options appropriate to run their cmakes.
set(pybind11_FIND_QUIETLY ON)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/pybind11)
add_subdirectory(third_party/rapidcheck)
if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/pybind11)
else()
set(pybind11_FIND_QUIETLY ON)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/pybind11)

if (ONNX_MLIR_ENABLE_STABLEHLO)
add_subdirectory(third_party/stablehlo EXCLUDE_FROM_ALL)
endif()
add_subdirectory(third_party/rapidcheck)

if (NOT TARGET benchmark)
set(BENCHMARK_USE_BUNDLED_GTEST OFF)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
set(BENCHMARK_ENABLE_TESTING OFF)
set(BENCHMARK_ENABLE_WERROR OFF)
# Since LLVM requires C++11 (or higher) it is safe to assume that std::regex is available.
set(HAVE_STD_REGEX ON CACHE BOOL "OK" FORCE)
add_subdirectory(third_party/benchmark)
if (ONNX_MLIR_ENABLE_STABLEHLO)
add_subdirectory(third_party/stablehlo EXCLUDE_FROM_ALL)
endif()

if (NOT TARGET benchmark)
set(BENCHMARK_USE_BUNDLED_GTEST OFF)
set(BENCHMARK_ENABLE_GTEST_TESTS OFF)
set(BENCHMARK_ENABLE_TESTING OFF)
set(BENCHMARK_ENABLE_WERROR OFF)
# Since LLVM requires C++11 (or higher) it is safe to assume that std::regex is available.
set(HAVE_STD_REGEX ON CACHE BOOL "OK" FORCE)
add_subdirectory(third_party/benchmark)
endif()
endif()

# All libraries and executables coming from llvm or ONNX-MLIR have had their
Expand Down Expand Up @@ -207,8 +216,12 @@ if (ONNX_MLIR_ENABLE_STABLEHLO)
add_compile_definitions(ONNX_MLIR_ENABLE_STABLEHLO)
endif()

add_subdirectory(utils)
add_subdirectory(include)
add_subdirectory(src)
add_subdirectory(docs)
add_subdirectory(test)
if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_subdirectory(src)
else()
add_subdirectory(utils)
add_subdirectory(include)
add_subdirectory(src)
add_subdirectory(docs)
add_subdirectory(test)
endif()
54 changes: 32 additions & 22 deletions MLIR.cmake
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
# SPDX-License-Identifier: Apache-2.0

# Must unset LLVM_DIR in cache. Otherwise, when MLIR_DIR changes LLVM_DIR
# won't change accordingly.
unset(LLVM_DIR CACHE)
if (NOT DEFINED MLIR_DIR)
message(FATAL_ERROR "MLIR_DIR is not configured but it is required. "
"Set the cmake option MLIR_DIR, e.g.,\n"
" cmake -DMLIR_DIR=/path/to/llvm-project/build/lib/cmake/mlir ..\n"
)
endif()
if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
# This function is defined in llvm_project.
# Define a dummy function for PYRUNTIME_LIGHT.
# If needed, the definition from llvm_project can be copied.
function(llvm_update_compile_flags name)
endfunction()
else()
# Must unset LLVM_DIR in cache. Otherwise, when MLIR_DIR changes LLVM_DIR
# won't change accordingly.
unset(LLVM_DIR CACHE)
if (NOT DEFINED MLIR_DIR)
message(FATAL_ERROR "MLIR_DIR is not configured but it is required. "
"Set the cmake option MLIR_DIR, e.g.,\n"
" cmake -DMLIR_DIR=/path/to/llvm-project/build/lib/cmake/mlir ..\n"
)
endif()

find_package(MLIR REQUIRED CONFIG)
find_package(MLIR REQUIRED CONFIG)

message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")

list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")

include(TableGen)
include(AddLLVM)
include(AddMLIR)
include(TableGen)
include(AddLLVM)
include(AddMLIR)

include(HandleLLVMOptions)
include(HandleLLVMOptions)

include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})

add_definitions(${LLVM_DEFINITIONS})
add_definitions(${LLVM_DEFINITIONS})
endif()

set(BUILD_SHARED_LIBS ${LLVM_ENABLE_SHARED_LIBS} CACHE BOOL "" FORCE)
message(STATUS "BUILD_SHARED_LIBS : " ${BUILD_SHARED_LIBS})
Expand Down Expand Up @@ -158,7 +166,9 @@ function(add_onnx_mlir_library name)
)

if (NOT ARG_EXCLUDE_FROM_OM_LIBS)
set_property(GLOBAL APPEND PROPERTY ONNX_MLIR_LIBS ${name})
if (NOT ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
set_property(GLOBAL APPEND PROPERTY ONNX_MLIR_LIBS ${name})
endif()
endif()

add_library(${name} ${ARG_UNPARSED_ARGUMENTS})
Expand Down
39 changes: 39 additions & 0 deletions docs/BuildPyRuntimeLit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# How to build and use PyRuntime lit

## Purpsoe

PyRuntime lit is a different way to build the original PyRuntime (src/Runtime/python).
All necessary dependence, such as llvm_project and onnx-mlir compiler is removed. The purpose is to easily build the python driver for the model execution on
different systems. Currently, only the OMTenserUtils (src/Runtime), Python driver (src/Runtime/python), third_party/onnx and third_party/pybind11 are built.

The build of PyRuntime lit is controlled by a CMake option: ONNX_MLIR_ENABLE_PYRUNTIME_LIT. Without this option to cmake, the whole system remains the same.

## Functionalities
1. Build the python driver without llvm_project and onnx-mlir compiler built.
2. The python driver can be used with utils/RunONNXModel.py, or onnxmlir python package.
3. With PyRuntime lit, the compiler has not been built locally and docker image of onnx-mlir has to be usd to compile the model. The onnxmlir package contains
the python code to use python docker package to perform the compilation. Alternatively, the old script, onnx-mlir/docker/onnx-mlir.py, can do the fulfill the same task with subprocess and docker CLI.

## How to use
You can find the script for build and run at "onnx-mlir/utils/build-pyruntime-lit.sh.
```
#!/bin/bash

# Assume you are in an empty directory for build in cloned onnx-mlir.
# Usually it is "your_path/onnx-mlir/build"
# then you can run this script as "../util/build-pyruntime-lit.sh"

cmake .. -DONNX_MLIR_ENABLE_PYRUNTIME_LIT=ON
make
make OMCreatePyRuntimePackage

# Install the package
pip3 install -e src/Runtime/python/onnxmlir
# -e is necessary for current package. Need to add resource description
# to install the pre-compiled binary

# Run test case
cd src/Runtime/python/onnxmlir/tests
python3 test_1.py
# Current limitation on where the model is
```
10 changes: 10 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# SPDX-License-Identifier: Apache-2.0

if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_subdirectory(Runtime)
# Accelerators introduces a target AcceleratorsInc. Define a dummy one here
add_custom_target(AcceleratorsInc
COMMAND echo "This is the dummy definition for AcceleratorsInc"
)
add_compile_definitions(ENABLE_PYRUNTIME_LIGHT)
return()
endif()

add_subdirectory(Accelerators)
add_subdirectory(Interface)
add_subdirectory(Dialect)
Expand Down
14 changes: 14 additions & 0 deletions src/Runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@

# SPDX-License-Identifier: Apache-2.0

if (NOT ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_subdirectory(jni)
add_subdirectory(omp)
endif()
add_subdirectory(python)

# TODO: should add for each accelerator its subdirectory that implements InitAccel##name
Expand Down Expand Up @@ -65,6 +67,17 @@ set_target_properties(OMTensorUtils
POSITION_INDEPENDENT_CODE TRUE
)

if (ONNX_MLIR_ENABLE_PYRUNTIME_LIGHT)
add_compile_definitions(ENABLE_PYRUNTIME_LIGHT)
add_onnx_mlir_library(OMExecutionSession
ExecutionSession.cpp

EXCLUDE_FROM_OM_LIBS

LINK_LIBS PUBLIC
OMTensorUtils
)
else()
add_onnx_mlir_library(OMExecutionSession
ExecutionSession.cpp

Expand All @@ -74,6 +87,7 @@ add_onnx_mlir_library(OMExecutionSession
OMTensorUtils
LLVMSupport
)
endif()
set_target_properties(OMExecutionSession
PROPERTIES
POSITION_INDEPENDENT_CODE TRUE
Expand Down
46 changes: 46 additions & 0 deletions src/Runtime/ExecutionSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
#include <sstream>
#include <vector>

#ifndef ENABLE_PYRUNTIME_LIGHT
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Path.h"
#else
#include <dlfcn.h>
#endif

#include "ExecutionSession.hpp"
#include "OMTensorListHelper.hpp"
Expand All @@ -44,16 +48,24 @@ void ExecutionSession::Init(

// If there is no tag, use the model filename without extension as a tag.
if (tag == "") {
// ToFix: equivalent implementation of llvm utilities.
// The would not be an urgent issue, because tag is usually "NONE"
#ifndef ENABLE_PYRUNTIME_LIGHT
std::string fname = llvm::sys::path::filename(sharedLibPath).str();
llvm::SmallString<256> fnameWithoutExt(fname);
llvm::sys::path::replace_extension(fnameWithoutExt, "");
tag = fnameWithoutExt.str().lower();
#endif
}

// tag = "NONE" to use functions without tag.
std::string lowDashTag;
// ToFix: equivalent implementation of llv::StringRef
#ifndef ENABLE_PYRUNTIME_LIGHT
// Assume tag is always NONE
if (!llvm::StringRef(tag).equals_insensitive("NONE"))
lowDashTag = "_" + tag;
#endif

#if defined(_WIN32)
// Use functions without tags on Windows since we cannot define at compile
Expand All @@ -63,31 +75,55 @@ void ExecutionSession::Init(
#endif

// Init symbols used by execution session.
#ifndef ENABLE_PYRUNTIME_LIGHT
_sharedLibraryHandle =
llvm::sys::DynamicLibrary::getLibrary(sharedLibPath.c_str());
if (!_sharedLibraryHandle.isValid())
throw std::runtime_error(reportLibraryOpeningError(sharedLibPath));
#else
// Copy code from llvm/lib/Support/DynamicLibrary.cpp, especially the flags
// ToFix: copy the lock related code too.
_sharedLibraryHandle = dlopen(sharedLibPath.c_str(), RTLD_LAZY | RTLD_GLOBAL);
if (!_sharedLibraryHandle)
throw std::runtime_error(reportLibraryOpeningError(sharedLibPath));
#endif

std::string queryEntryPointsNameWithTag = _queryEntryPointsName + lowDashTag;
#ifndef ENABLE_PYRUNTIME_LIGHT
_queryEntryPointsFunc = reinterpret_cast<queryEntryPointsFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(
queryEntryPointsNameWithTag.c_str()));
#else
_queryEntryPointsFunc = reinterpret_cast<queryEntryPointsFuncType>(
dlsym(_sharedLibraryHandle, queryEntryPointsNameWithTag.c_str()));
#endif

if (!_queryEntryPointsFunc)
throw std::runtime_error(
reportSymbolLoadingError(queryEntryPointsNameWithTag));

std::string inputSignatureNameWithTag = _inputSignatureName + lowDashTag;
#ifndef ENABLE_PYRUNTIME_LIGHT
_inputSignatureFunc = reinterpret_cast<signatureFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(
inputSignatureNameWithTag.c_str()));
#else
_inputSignatureFunc = reinterpret_cast<signatureFuncType>(
dlsym(_sharedLibraryHandle, inputSignatureNameWithTag.c_str()));
#endif
if (!_inputSignatureFunc)
throw std::runtime_error(
reportSymbolLoadingError(inputSignatureNameWithTag));

std::string outputSignatureNameWithTag = _outputSignatureName + lowDashTag;
#ifndef ENABLE_PYRUNTIME_LIGHT
_outputSignatureFunc = reinterpret_cast<signatureFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(
outputSignatureNameWithTag.c_str()));
#else
_outputSignatureFunc = reinterpret_cast<signatureFuncType>(
dlsym(_sharedLibraryHandle, outputSignatureNameWithTag.c_str()));
#endif
if (!_outputSignatureFunc)
throw std::runtime_error(
reportSymbolLoadingError(outputSignatureNameWithTag));
Expand All @@ -114,8 +150,13 @@ void ExecutionSession::Init(
}

ExecutionSession::~ExecutionSession() {
#ifndef ENABLE_PYRUNTIME_LIGHT
if (_sharedLibraryHandle.isValid())
llvm::sys::DynamicLibrary::closeLibrary(_sharedLibraryHandle);
#else
if (!_sharedLibraryHandle)
dlclose(_sharedLibraryHandle);
#endif
}

// =============================================================================
Expand All @@ -132,8 +173,13 @@ const std::string *ExecutionSession::queryEntryPoints(
void ExecutionSession::setEntryPoint(const std::string &entryPointName) {
if (!isInitialized)
throw std::runtime_error(reportInitError());
#ifndef ENABLE_PYRUNTIME_LIGHT
_entryPointFunc = reinterpret_cast<entryPointFuncType>(
_sharedLibraryHandle.getAddressOfSymbol(entryPointName.c_str()));
#else
_entryPointFunc = reinterpret_cast<entryPointFuncType>(
dlsym(_sharedLibraryHandle, entryPointName.c_str()));
#endif
if (!_entryPointFunc)
throw std::runtime_error(reportSymbolLoadingError(entryPointName));
_entryPointName = entryPointName;
Expand Down
Loading

0 comments on commit e24fd5c

Please sign in to comment.
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy