Logo ROOT   6.10/00
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
TDFInterface.cxx
Go to the documentation of this file.
1 // Author: Enrico Guiraud, Danilo Piparo CERN 03/2017
2 
3 /*************************************************************************
4  * Copyright (C) 1995-2016, Rene Brun and Fons Rademakers. *
5  * All rights reserved. *
6  * *
7  * For the licensing terms see $ROOTSYS/LICENSE. *
8  * For the list of contributors see $ROOTSYS/README/CREDITS. *
9  *************************************************************************/
10 
11 #include "TClass.h"
12 #include "TRegexp.h"
13 
14 #include "ROOT/TDFInterface.hxx"
15 
16 #include <vector>
17 #include <string>
18 using namespace ROOT::Experimental::TDF;
19 using namespace ROOT::Internal::TDF;
20 using namespace ROOT::Detail::TDF;
21 
22 namespace ROOT {
23 namespace Experimental {
24 namespace TDF {
25 // Explicit instantiation definitions of TInterface for the three node types below (presumably paired with extern template declarations in the header - TODO confirm)
26 template class TInterface<TLoopManager>;
27 template class TInterface<TFilterBase>;
28 template class TInterface<TCustomColumnBase>;
29 }
30 }
31 
32 namespace Internal {
33 namespace TDF {
34 // Match expression against names of branches passed as parameter
35 // Return vector of names of the branches used in the expression
36 std::vector<std::string> GetUsedBranchesNames(const std::string expression, TObjArray *branches,
37  const std::vector<std::string> &tmpBranches)
38 {
39  // Check what branches and temporary branches are used in the expression
40  // To help matching the regex
41  std::string paddedExpr = " " + expression + " ";
42  int paddedExprLen = paddedExpr.size();
43  static const std::string regexBit("[^a-zA-Z0-9_]");
44  std::vector<std::string> usedBranches;
45  for (auto bro : *branches) {
46  auto brName = bro->GetName();
47  std::string bNameRegexContent = regexBit + brName + regexBit;
48  TRegexp bNameRegex(bNameRegexContent.c_str());
49  if (-1 != bNameRegex.Index(paddedExpr.c_str(), &paddedExprLen)) {
50  usedBranches.emplace_back(brName);
51  }
52  }
53  for (auto brName : tmpBranches) {
54  std::string bNameRegexContent = regexBit + brName + regexBit;
55  TRegexp bNameRegex(bNameRegexContent.c_str());
56  if (-1 != bNameRegex.Index(paddedExpr.c_str(), &paddedExprLen)) {
57  usedBranches.emplace_back(brName.c_str());
58  }
59  }
60  return usedBranches;
61 }
62 
63 // Jit a string filter or a string temporary column, call this->Define or this->Filter as needed
64 // Return pointer to the new functional chain node returned by the call, cast to Long_t
65 Long_t JitTransformation(void *thisPtr, const std::string &methodName, const std::string &nodeTypeName,
66  const std::string &name, const std::string &expression, TObjArray *branches,
67  const std::vector<std::string> &tmpBranches,
68  const std::map<std::string, TmpBranchBasePtr_t> &tmpBookedBranches, TTree *tree)
69 {
70  auto usedBranches = GetUsedBranchesNames(expression, branches, tmpBranches);
71  auto exprNeedsVariables = !usedBranches.empty();
72 
73  // Move to the preparation of the jitting
74  // We put all of the jitted entities in a namespace called
75  // __tdf_filter_N, where N is a monotonically increasing index.
76  TInterpreter::EErrorCode interpErrCode;
77  std::vector<std::string> usedBranchesTypes;
78  std::stringstream ss;
79  static unsigned int iNs = 0U;
80  ss << "__tdf_" << iNs++;
81  const auto nsName = ss.str();
82  ss.str("");
83 
84  if (exprNeedsVariables) {
85  // Declare a namespace and inside it the variables in the expression
86  ss << "namespace " << nsName;
87  ss << " {\n";
88  for (auto brName : usedBranches) {
89  // The map is a const reference, so no operator[]
90  auto tmpBrIt = tmpBookedBranches.find(brName);
91  auto tmpBr = tmpBrIt == tmpBookedBranches.end() ? nullptr : tmpBrIt->second.get();
92  auto brTypeName = ColumnName2ColumnTypeName(brName, tree, tmpBr);
93  ss << brTypeName << " " << brName << ";\n";
94  usedBranchesTypes.emplace_back(brTypeName);
95  }
96  ss << "}";
97  auto variableDeclarations = ss.str();
98  ss.str("");
99  // We need ProcessLine to trigger auto{parsing,loading} where needed
100  gInterpreter->ProcessLine(variableDeclarations.c_str(), &interpErrCode);
101  if (TInterpreter::EErrorCode::kNoError != interpErrCode) {
102  std::string msg = "Cannot declare these variables ";
103  msg += " ";
104  msg += variableDeclarations;
105  if (TInterpreter::EErrorCode::kNoError != interpErrCode) {
106  msg += "\nInterpreter error code is " + std::to_string(interpErrCode) + ".";
107  }
108  throw std::runtime_error(msg);
109  }
110  }
111 
112  // Declare within the same namespace, the expression to make sure it
113  // is proper C++
114  ss << "namespace " << nsName << "{ auto res = " << expression << ";}\n";
115  // Headers must have been parsed and libraries loaded: we can use Declare
116  if (!gInterpreter->Declare(ss.str().c_str())) {
117  std::string msg = "Cannot interpret this expression: ";
118  msg += " ";
119  msg += ss.str();
120  throw std::runtime_error(msg);
121  }
122 
123  // Now we build the lambda and we invoke the method with it in the jitted world
124  ss.str("");
125  ss << "[](";
126  for (unsigned int i = 0; i < usedBranchesTypes.size(); ++i) {
127  // We pass by reference to avoid expensive copies
128  ss << usedBranchesTypes[i] << "& " << usedBranches[i] << ", ";
129  }
130  if (!usedBranchesTypes.empty()) ss.seekp(-2, ss.cur);
131  ss << "){ return " << expression << ";}";
132  auto filterLambda = ss.str();
133 
134  // Here we have two cases: filter and column
135  ss.str("");
136  ss << "((" << nodeTypeName << "*)" << thisPtr << ")->" << methodName << "(";
137  if (methodName == "Define") {
138  ss << "\"" << name << "\", ";
139  }
140  ss << filterLambda << ", {";
141  for (auto brName : usedBranches) {
142  ss << "\"" << brName << "\", ";
143  }
144  if (exprNeedsVariables) ss.seekp(-2, ss.cur); // remove the last ",
145  ss << "}";
146 
147  if (methodName == "Filter") {
148  ss << ", \"" << name << "\"";
149  }
150 
151  ss << ");";
152 
153  auto retVal = gInterpreter->ProcessLine(ss.str().c_str(), &interpErrCode);
154  if (TInterpreter::EErrorCode::kNoError != interpErrCode || !retVal) {
155  std::string msg = "Cannot interpret the invocation to " + methodName + ": ";
156  msg += " ";
157  msg += ss.str();
158  if (TInterpreter::EErrorCode::kNoError != interpErrCode) {
159  msg += "\nInterpreter error code is " + std::to_string(interpErrCode) + ".";
160  }
161  throw std::runtime_error(msg);
162  }
163  return retVal;
164 }
165 
166 // Jit and call something equivalent to "this->BuildAndBook<BranchTypes...>(params...)"
167 // (see comments in the body for actual jitted code)
168 void JitBuildAndBook(const ColumnNames_t &bl, const std::string &nodeTypename, void *thisPtr, const std::type_info &art,
169  const std::type_info &at, const void *r, TTree *tree, unsigned int nSlots,
170  const std::map<std::string, TmpBranchBasePtr_t> &tmpBranches)
171 {
172  gInterpreter->ProcessLine("#include \"ROOT/TDataFrame.hxx\"");
173  auto nBranches = bl.size();
174 
175  // retrieve pointers to temporary columns (null if the column is not temporary)
176  std::vector<TCustomColumnBase *> tmpBranchPtrs(nBranches, nullptr);
177  for (auto i = 0u; i < nBranches; ++i) {
178  auto tmpBranchIt = tmpBranches.find(bl[i]);
179  if (tmpBranchIt != tmpBranches.end()) tmpBranchPtrs[i] = tmpBranchIt->second.get();
180  }
181 
182  // retrieve branch type names as strings
183  std::vector<std::string> branchTypeNames(nBranches);
184  for (auto i = 0u; i < nBranches; ++i) {
185  const auto branchTypeName = ColumnName2ColumnTypeName(bl[i], tree, tmpBranchPtrs[i]);
186  if (branchTypeName.empty()) {
187  std::string exceptionText = "The type of column ";
188  exceptionText += bl[i];
189  exceptionText += " could not be guessed. Please specify one.";
190  throw std::runtime_error(exceptionText.c_str());
191  }
192  branchTypeNames[i] = branchTypeName;
193  }
194 
195  // retrieve type of result of the action as a string
196  auto actionResultTypeClass = TClass::GetClass(art);
197  if (!actionResultTypeClass) {
198  std::string exceptionText = "An error occurred while inferring the result type of an operation.";
199  throw std::runtime_error(exceptionText.c_str());
200  }
201  const auto actionResultTypeName = actionResultTypeClass->GetName();
202 
203  // retrieve type of action as a string
204  auto actionTypeClass = TClass::GetClass(at);
205  if (!actionTypeClass) {
206  std::string exceptionText = "An error occurred while inferring the action type of the operation.";
207  throw std::runtime_error(exceptionText.c_str());
208  }
209  const auto actionTypeName = actionTypeClass->GetName();
210 
211  // createAction_str will contain the following:
212  // ROOT::Internal::TDF::CallBuildAndBook<nodeType, actionType, branchType1, branchType2...>(
213  // reinterpret_cast<nodeType*>(thisPtr), *reinterpret_cast<ROOT::ColumnNames_t*>(&bl),
214  // *reinterpret_cast<actionResultType*>(r), reinterpret_cast<ActionType*>(nullptr))
215  std::stringstream createAction_str;
216  createAction_str << "ROOT::Internal::TDF::CallBuildAndBook<" << nodeTypename << ", " << actionTypeName;
217  for (auto &branchTypeName : branchTypeNames) createAction_str << ", " << branchTypeName;
218  createAction_str << ">("
219  << "reinterpret_cast<" << nodeTypename << "*>(" << thisPtr << "), "
220  << "*reinterpret_cast<ROOT::Detail::TDF::ColumnNames_t*>(" << &bl << "), " << nSlots
221  << ", *reinterpret_cast<" << actionResultTypeName << "*>(" << r << "));";
222  auto error = TInterpreter::EErrorCode::kNoError;
223  gInterpreter->ProcessLine(createAction_str.str().c_str(), &error);
224  if (error) {
225  std::string exceptionText = "An error occurred while jitting this action:\n";
226  exceptionText += createAction_str.str();
227  throw std::runtime_error(exceptionText.c_str());
228  }
229 }
230 } // end ns TDF
231 } // end ns Internal
232 } // end ns ROOT
An array of TObjects.
Definition: TObjArray.h:37
Regular expression class.
Definition: TRegexp.h:31
#define gInterpreter
Definition: TInterpreter.h:499
TRandom2 r(17)
std::string ColumnName2ColumnTypeName(const std::string &colName, TTree *tree, TCustomColumnBase *tmpBranch)
Return a string containing the type of the given branch.
Definition: TDFUtils.cxx:30
Long_t JitTransformation(void *thisPtr, const std::string &methodName, const std::string &nodeTypeName, const std::string &name, const std::string &expression, TObjArray *branches, const std::vector< std::string > &tmpBranches, const std::map< std::string, TmpBranchBasePtr_t > &tmpBookedBranches, TTree *tree)
long Long_t
Definition: RtypesCore.h:50
static TClass * GetClass(const char *name, Bool_t load=kTRUE, Bool_t silent=kFALSE)
Static method returning pointer to TClass of the specified class name.
Definition: TClass.cxx:2885
void JitBuildAndBook(const ColumnNames_t &bl, const std::string &nodeTypename, void *thisPtr, const std::type_info &art, const std::type_info &at, const void *r, TTree *tree, unsigned int nSlots, const std::map< std::string, TmpBranchBasePtr_t > &tmpBranches)
std::vector< std::string > GetUsedBranchesNames(const std::string, TObjArray *, const std::vector< std::string > &)
char name[80]
Definition: TGX11.cxx:109
The public interface to the TDataFrame federation of classes.