GraphDump.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #define LOG_TAG "GraphDump"
#include "GraphDump.h"

#include "HalInterfaces.h"

#include <android-base/logging.h>

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
  23. namespace android {
  24. namespace nn {
  25. // class Dumper is a wrapper around an std::ostream (if instantiated
  26. // with a pointer to a stream) or around LOG(INFO) (otherwise).
  27. //
  28. // Send fragments of output to it with operator<<(), as per usual
  29. // stream conventions. Unlike with LOG(INFO), there is no implicit
  30. // end-of-line. To end a line, send Dumper::endl.
  31. //
  32. // Example:
  33. //
  34. // Dumper d(nullptr); // will go to LOG(INFO)
  35. // d << "These words are";
  36. // d << " all" << " on";
  37. // d << " the same line." << Dumper::endl;
  38. //
  39. namespace {
  40. class Dumper {
  41. public:
  42. Dumper(std::ostream* outStream) : mStream(outStream) { }
  43. Dumper(const Dumper&) = delete;
  44. void operator=(const Dumper&) = delete;
  45. template <typename T>
  46. Dumper& operator<<(const T& val) {
  47. mStringStream << val;
  48. return *this;
  49. }
  50. class EndlType { };
  51. Dumper& operator<<(EndlType) {
  52. if (mStream) {
  53. *mStream << mStringStream.str() << std::endl;
  54. } else {
  55. // TODO: There is a limit of how long a single LOG line
  56. // can be; extra characters are truncated. (See
  57. // LOGGER_ENTRY_MAX_PAYLOAD and LOGGER_ENTRY_MAX_LEN.) We
  58. // may want to figure out the linebreak rules for the .dot
  59. // format and try to ensure that we generate correct .dot
  60. // output whose lines do not exceed some maximum length.
  61. // The intelligence for breaking the lines might have to
  62. // live in graphDump() rather than in the Dumper class, so
  63. // that it can be sensitive to the .dot format.
  64. LOG(INFO) << mStringStream.str();
  65. }
  66. std::ostringstream empty;
  67. std::swap(mStringStream, empty);
  68. return *this;
  69. }
  70. static const EndlType endl;
  71. private:
  72. std::ostream* mStream;
  73. std::ostringstream mStringStream;
  74. };
  75. const Dumper::EndlType Dumper::endl;
  76. }
  77. // Provide short name for OperandType value.
  78. static std::string translate(OperandType type) {
  79. switch (type) {
  80. case OperandType::FLOAT32: return "F32";
  81. case OperandType::INT32: return "I32";
  82. case OperandType::UINT32: return "U32";
  83. case OperandType::TENSOR_FLOAT32: return "TF32";
  84. case OperandType::TENSOR_INT32: return "TI32";
  85. case OperandType::TENSOR_QUANT8_ASYMM: return "TQ8A";
  86. case OperandType::OEM: return "OEM";
  87. case OperandType::TENSOR_OEM_BYTE: return "TOEMB";
  88. default: return toString(type);
  89. }
  90. }
  91. // If the specified Operand of the specified Model has OperandType
  92. // nnType corresponding to C++ type cppType and is of
  93. // OperandLifeTime::CONSTANT_COPY, then write the Operand's value to
  94. // the Dumper.
  95. namespace {
  96. template<OperandType nnType, typename cppType>
  97. void tryValueDump(Dumper& dump, const Model& model, const Operand& opnd) {
  98. if (opnd.type != nnType ||
  99. opnd.lifetime != OperandLifeTime::CONSTANT_COPY ||
  100. opnd.location.length != sizeof(cppType)) {
  101. return;
  102. }
  103. cppType val;
  104. memcpy(&val, &model.operandValues[opnd.location.offset], sizeof(cppType));
  105. dump << " = " << val;
  106. }
  107. }
  108. void graphDump(const char* name, const Model& model, std::ostream* outStream) {
  109. // Operand nodes are named "d" (operanD) followed by operand index.
  110. // Operation nodes are named "n" (operatioN) followed by operation index.
  111. // (These names are not the names that are actually displayed -- those
  112. // names are given by the "label" attribute.)
  113. Dumper dump(outStream);
  114. dump << "// " << name << Dumper::endl;
  115. dump << "digraph {" << Dumper::endl;
  116. // model inputs and outputs
  117. std::set<uint32_t> modelIO;
  118. for (unsigned i = 0, e = model.inputIndexes.size(); i < e; i++) {
  119. modelIO.insert(model.inputIndexes[i]);
  120. }
  121. for (unsigned i = 0, e = model.outputIndexes.size(); i < e; i++) {
  122. modelIO.insert(model.outputIndexes[i]);
  123. }
  124. // model operands
  125. for (unsigned i = 0, e = model.operands.size(); i < e; i++) {
  126. dump << " d" << i << " [";
  127. if (modelIO.count(i)) {
  128. dump << "style=filled fillcolor=black fontcolor=white ";
  129. }
  130. dump << "label=\"" << i;
  131. const Operand& opnd = model.operands[i];
  132. const char* kind = nullptr;
  133. switch (opnd.lifetime) {
  134. case OperandLifeTime::CONSTANT_COPY:
  135. kind = "COPY";
  136. break;
  137. case OperandLifeTime::CONSTANT_REFERENCE:
  138. kind = "REF";
  139. break;
  140. case OperandLifeTime::NO_VALUE:
  141. kind = "NO";
  142. break;
  143. default:
  144. // nothing interesting
  145. break;
  146. }
  147. if (kind) {
  148. dump << ": " << kind;
  149. }
  150. dump << "\\n" << translate(opnd.type);
  151. tryValueDump<OperandType::FLOAT32, float>(dump, model, opnd);
  152. tryValueDump<OperandType::INT32, int>(dump, model, opnd);
  153. tryValueDump<OperandType::UINT32, unsigned>(dump, model, opnd);
  154. if (opnd.dimensions.size()) {
  155. dump << "(";
  156. for (unsigned i = 0, e = opnd.dimensions.size(); i < e; i++) {
  157. if (i > 0) {
  158. dump << "x";
  159. }
  160. dump << opnd.dimensions[i];
  161. }
  162. dump << ")";
  163. }
  164. dump << "\"]" << Dumper::endl;
  165. }
  166. // model operations
  167. for (unsigned i = 0, e = model.operations.size(); i < e; i++) {
  168. const Operation& operation = model.operations[i];
  169. dump << " n" << i << " [shape=box";
  170. const uint32_t maxArity = std::max(operation.inputs.size(), operation.outputs.size());
  171. if (maxArity > 1) {
  172. if (maxArity == operation.inputs.size()) {
  173. dump << " ordering=in";
  174. } else {
  175. dump << " ordering=out";
  176. }
  177. }
  178. dump << " label=\"" << i << ": "
  179. << toString(operation.type) << "\"]" << Dumper::endl;
  180. {
  181. // operation inputs
  182. for (unsigned in = 0, inE = operation.inputs.size(); in < inE; in++) {
  183. dump << " d" << operation.inputs[in] << " -> n" << i;
  184. if (inE > 1) {
  185. dump << " [label=" << in << "]";
  186. }
  187. dump << Dumper::endl;
  188. }
  189. }
  190. {
  191. // operation outputs
  192. for (unsigned out = 0, outE = operation.outputs.size(); out < outE; out++) {
  193. dump << " n" << i << " -> d" << operation.outputs[out];
  194. if (outE > 1) {
  195. dump << " [label=" << out << "]";
  196. }
  197. dump << Dumper::endl;
  198. }
  199. }
  200. }
  201. dump << "}" << Dumper::endl;
  202. }
  203. } // namespace nn
  204. } // namespace android