GlobalMergePass.cpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. /*
  2. * Copyright 2016-2017, The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "GlobalMergePass.h"
  17. #include "llvm/IR/Constants.h"
  18. #include "llvm/IR/DataLayout.h"
  19. #include "llvm/IR/GlobalVariable.h"
  20. #include "llvm/IR/IRBuilder.h"
  21. #include "llvm/IR/Instructions.h"
  22. #include "llvm/IR/Module.h"
  23. #include "llvm/Pass.h"
  24. #include "llvm/Support/Debug.h"
  25. #include "llvm/Support/raw_ostream.h"
  26. #include "Context.h"
  27. #include "RSAllocationUtils.h"
  28. #include <functional>
  29. #define DEBUG_TYPE "rs2spirv-global-merge"
  30. using namespace llvm;
  31. namespace rs2spirv {
  32. namespace {
  33. class GlobalMergePass : public ModulePass {
  34. public:
  35. static char ID;
  36. GlobalMergePass(bool CPU = false) : ModulePass(ID), mForCPU(CPU) {}
  37. const char *getPassName() const override { return "GlobalMergePass"; }
  38. bool runOnModule(Module &M) override {
  39. DEBUG(dbgs() << "RS2SPIRVGlobalMergePass\n");
  40. SmallVector<GlobalVariable *, 8> Globals;
  41. if (!collectGlobals(M, Globals)) {
  42. return false; // Module not modified.
  43. }
  44. SmallVector<Type *, 8> Tys;
  45. Tys.reserve(Globals.size());
  46. Context &RS2SPIRVCtxt = Context::getInstance();
  47. uint32_t index = 0;
  48. for (GlobalVariable *GV : Globals) {
  49. Tys.push_back(GV->getValueType());
  50. const char *name = GV->getName().data();
  51. RS2SPIRVCtxt.addExportVarIndex(name, index);
  52. index++;
  53. }
  54. LLVMContext &LLVMCtxt = M.getContext();
  55. StructType *MergedTy = StructType::create(LLVMCtxt, "struct.__GPUBuffer");
  56. MergedTy->setBody(Tys, false);
  57. // Size calculation has to consider data layout
  58. const DataLayout &DL = M.getDataLayout();
  59. const uint64_t BufferSize = DL.getTypeAllocSize(MergedTy);
  60. RS2SPIRVCtxt.setGlobalSize(BufferSize);
  61. Type *BufferVarTy = mForCPU ? static_cast<Type *>(PointerType::getUnqual(
  62. Type::getInt8Ty(M.getContext())))
  63. : static_cast<Type *>(MergedTy);
  64. GlobalVariable *MergedGV =
  65. new GlobalVariable(M, BufferVarTy, false, GlobalValue::ExternalLinkage,
  66. nullptr, "__GPUBlock");
  67. // For CPU, create a constant struct for initial values, which has each of
  68. // its fields initialized to the original value of the corresponding global
  69. // variable.
  70. // During the script initialization, the driver should copy these initial
  71. // values to the global buffer.
  72. if (mForCPU) {
  73. CreateInitFunction(LLVMCtxt, M, MergedGV, MergedTy, BufferSize, Globals);
  74. }
  75. const bool forCPU = mForCPU;
  76. IntegerType *const Int32Ty = Type::getInt32Ty(LLVMCtxt);
  77. ConstantInt *const Zero = ConstantInt::get(Int32Ty, 0);
  78. Value *Idx[] = {Zero, nullptr};
  79. auto InstMaker = [forCPU, MergedGV, MergedTy,
  80. &Idx](Instruction *InsertBefore) {
  81. Value *Base = MergedGV;
  82. if (forCPU) {
  83. LoadInst *Load = new LoadInst(MergedGV, "", InsertBefore);
  84. DEBUG(Load->dump());
  85. Base = new BitCastInst(Load, PointerType::getUnqual(MergedTy), "",
  86. InsertBefore);
  87. DEBUG(Base->dump());
  88. }
  89. GetElementPtrInst *GEP = GetElementPtrInst::CreateInBounds(
  90. MergedTy, Base, Idx, "", InsertBefore);
  91. DEBUG(GEP->dump());
  92. return GEP;
  93. };
  94. for (size_t i = 0, e = Globals.size(); i != e; ++i) {
  95. GlobalVariable *G = Globals[i];
  96. Idx[1] = ConstantInt::get(Int32Ty, i);
  97. ReplaceAllUsesWithNewInstructions(G, std::cref(InstMaker));
  98. G->eraseFromParent();
  99. }
  100. // Return true, as the pass modifies module.
  101. return true;
  102. }
  103. private:
  104. // In the User of Value Old, replaces all references of Old with Value New
  105. static inline void ReplaceUse(User *U, Value *Old, Value *New) {
  106. for (unsigned i = 0, n = U->getNumOperands(); i < n; ++i) {
  107. if (U->getOperand(i) == Old) {
  108. U->getOperandUse(i) = New;
  109. }
  110. }
  111. }
  112. // Replaces each use of V with new instructions created by
  113. // funcCreateAndInsert and inserted right before that use. In the cases where
  114. // the use is not an instruction, but a constant expression, recursively
  115. // replaces that constant expression with a newly constructed equivalent
  116. // instruction, before replacing V in that new instruction.
  117. static inline void ReplaceAllUsesWithNewInstructions(
  118. Value *V,
  119. std::function<Instruction *(Instruction *)> funcCreateAndInsert) {
  120. SmallVector<User *, 8> Users(V->user_begin(), V->user_end());
  121. for (User *U : Users) {
  122. if (Instruction *Inst = dyn_cast<Instruction>(U)) {
  123. DEBUG(dbgs() << "\nBefore replacement:\n");
  124. DEBUG(Inst->dump());
  125. DEBUG(dbgs() << "----\n");
  126. ReplaceUse(U, V, funcCreateAndInsert(Inst));
  127. DEBUG(Inst->dump());
  128. } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
  129. auto InstMaker([CE, V, &funcCreateAndInsert](Instruction *UserOfU) {
  130. Instruction *Inst = CE->getAsInstruction();
  131. Inst->insertBefore(UserOfU);
  132. ReplaceUse(Inst, V, funcCreateAndInsert(Inst));
  133. DEBUG(Inst->dump());
  134. return Inst;
  135. });
  136. ReplaceAllUsesWithNewInstructions(U, InstMaker);
  137. } else {
  138. DEBUG(U->dump());
  139. llvm_unreachable("Expecting only Instruction or ConstantExpr");
  140. }
  141. }
  142. }
  143. static inline void
  144. CreateInitFunction(LLVMContext &LLVMCtxt, Module &M, GlobalVariable *MergedGV,
  145. StructType *MergedTy, const uint64_t BufferSize,
  146. const SmallVectorImpl<GlobalVariable *> &Globals) {
  147. SmallVector<Constant *, 8> Initializers;
  148. Initializers.reserve(Globals.size());
  149. for (size_t i = 0, e = Globals.size(); i != e; ++i) {
  150. GlobalVariable *G = Globals[i];
  151. Initializers.push_back(G->getInitializer());
  152. }
  153. ArrayRef<Constant *> ArrInit(Initializers.begin(), Initializers.end());
  154. Constant *MergedInitializer = ConstantStruct::get(MergedTy, ArrInit);
  155. GlobalVariable *MergedInit =
  156. new GlobalVariable(M, MergedTy, true, GlobalValue::InternalLinkage,
  157. MergedInitializer, "__GPUBlock0");
  158. Function *UserInit = M.getFunction("init");
  159. // If there is no user-defined init() function, make the new global
  160. // initialization function the init().
  161. StringRef FName(UserInit ? ".rsov.global_init" : "init");
  162. Function *Func;
  163. FunctionType *FTy = FunctionType::get(Type::getVoidTy(LLVMCtxt), false);
  164. Func = Function::Create(FTy, GlobalValue::ExternalLinkage, FName, &M);
  165. BasicBlock *Blk = BasicBlock::Create(LLVMCtxt, "entry", Func);
  166. IRBuilder<> LLVMIRBuilder(Blk);
  167. LoadInst *Load = LLVMIRBuilder.CreateLoad(MergedGV);
  168. LLVMIRBuilder.CreateMemCpy(Load, MergedInit, BufferSize, 0);
  169. LLVMIRBuilder.CreateRetVoid();
  170. // If there is a user-defined init() function, add a call to the global
  171. // initialization function in the beginning of that function.
  172. if (UserInit) {
  173. BasicBlock &EntryBlk = UserInit->getEntryBlock();
  174. CallInst::Create(Func, {}, "", &EntryBlk.front());
  175. }
  176. }
  177. bool collectGlobals(Module &M, SmallVectorImpl<GlobalVariable *> &Globals) {
  178. for (GlobalVariable &GV : M.globals()) {
  179. assert(!GV.hasComdat() && "global variable has a comdat section");
  180. assert(!GV.hasSection() && "global variable has a non-default section");
  181. assert(!GV.isDeclaration() && "global variable is only a declaration");
  182. assert(!GV.isThreadLocal() && "global variable is thread-local");
  183. assert(GV.getType()->getAddressSpace() == 0 &&
  184. "global variable has non-default address space");
  185. // TODO: Constants accessed by kernels should be handled differently
  186. if (GV.isConstant()) {
  187. continue;
  188. }
  189. // Global Allocations are handled differently in separate passes
  190. if (isRSAllocation(GV)) {
  191. continue;
  192. }
  193. Globals.push_back(&GV);
  194. }
  195. return !Globals.empty();
  196. }
  197. bool mForCPU;
  198. };
  199. } // namespace
  200. char GlobalMergePass::ID = 0;
  201. ModulePass *createGlobalMergePass(bool CPU) { return new GlobalMergePass(CPU); }
  202. } // namespace rs2spirv