Logo ROOT   6.10/00
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
testPyKerasMulticlass.C
Go to the documentation of this file.
1 #include <iostream>
2 
3 #include "TString.h"
4 #include "TFile.h"
5 #include "TTree.h"
6 #include "TSystem.h"
7 #include "TROOT.h"
8 #include "TMVA/Factory.h"
9 #include "TMVA/Reader.h"
10 #include "TMVA/DataLoader.h"
11 #include "TMVA/PyMethodBase.h"
12 
14 from keras.models import Sequential\n\
15 from keras.layers.core import Dense, Activation\n\
16 from keras import initializations\n\
17 from keras.optimizers import Adam\n\
18 \n\
19 model = Sequential()\n\
20 model.add(Dense(64, init=\"glorot_normal\", activation=\"relu\", input_dim=4))\n\
21 model.add(Dense(4, init=\"glorot_normal\", activation=\"softmax\"))\n\
22 model.compile(loss=\"categorical_crossentropy\", optimizer=Adam(), metrics=[\"accuracy\",])\n\
23 model.save(\"kerasModelMulticlass.h5\")\n";
24 
26  // Get data file
27  std::cout << "Get test data..." << std::endl;
28  TString fname = "./tmva_example_multiple_background.root";
29  if (gSystem->AccessPathName(fname)){ // file does not exist in local directory
30  std::cout << "Create multiclass test data..." << std::endl;
31  TString createDataMacro = TString(gROOT->GetTutorialsDir()) + "/tmva/createData.C";
32  gROOT->ProcessLine(TString::Format(".L %s",createDataMacro.Data()));
33  gROOT->ProcessLine("create_MultipleBackground(200)");
34  std::cout << "Created " << fname << " for tests of the multiclass features" << std::endl;
35  }
36  TFile *input = TFile::Open(fname);
37 
38  // Build model from python file
39  std::cout << "Generate keras model..." << std::endl;
40  UInt_t ret;
41  ret = gSystem->Exec("echo '"+pythonSrc+"' > generateKerasModelMulticlass.py");
42  if(ret!=0){
43  std::cout << "[ERROR] Failed to write python code to file" << std::endl;
44  return 1;
45  }
46  ret = gSystem->Exec("python generateKerasModelMulticlass.py");
47  if(ret!=0){
48  std::cout << "[ERROR] Failed to generate model using python" << std::endl;
49  return 1;
50  }
51 
52  // Setup PyMVA and factory
53  std::cout << "Setup TMVA..." << std::endl;
55  TFile* outputFile = TFile::Open("ResultsTestPyKerasMulticlass.root", "RECREATE");
56  TMVA::Factory *factory = new TMVA::Factory("testPyKerasMulticlass", outputFile,
57  "!V:Silent:Color:!DrawProgressBar:AnalysisType=multiclass");
58 
59  // Load data
60  TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetTestPyKerasMulticlass");
61 
62  TTree *signal = (TTree*)input->Get("TreeS");
63  TTree *background0 = (TTree*)input->Get("TreeB0");
64  TTree *background1 = (TTree*)input->Get("TreeB1");
65  TTree *background2 = (TTree*)input->Get("TreeB2");
66  dataloader->AddTree(signal, "Signal");
67  dataloader->AddTree(background0, "Background_0");
68  dataloader->AddTree(background1, "Background_1");
69  dataloader->AddTree(background2, "Background_2");
70 
71  dataloader->AddVariable("var1");
72  dataloader->AddVariable("var2");
73  dataloader->AddVariable("var3");
74  dataloader->AddVariable("var4");
75 
76  dataloader->PrepareTrainingAndTestTree("",
77  "SplitMode=Random:NormMode=NumEvents:!V");
78 
79  // Book and train method
80  factory->BookMethod(dataloader, TMVA::Types::kPyKeras, "PyKeras",
81  "!H:!V:VarTransform=D,G:FilenameModel=kerasModelMulticlass.h5:FilenameTrainedModel=trainedKerasModelMulticlass.h5:NumEpochs=20:BatchSize=32:SaveBestOnly=false:Verbose=0");
82  std::cout << "Train model..." << std::endl;
83  factory->TrainAllMethods();
84 
85  // Clean-up
86  delete factory;
87  delete dataloader;
88  delete outputFile;
89 
90  // Setup reader
91  UInt_t numEvents = 100;
92  std::cout << "Run reader and classify " << numEvents << " events..." << std::endl;
93  TMVA::Reader *reader = new TMVA::Reader("!Color:Silent");
94  Float_t vars[4];
95  reader->AddVariable("var1", vars+0);
96  reader->AddVariable("var2", vars+1);
97  reader->AddVariable("var3", vars+2);
98  reader->AddVariable("var4", vars+3);
99  reader->BookMVA("PyKeras", "datasetTestPyKerasMulticlass/weights/testPyKerasMulticlass_PyKeras.weights.xml");
100 
101  // Get mean response of method on signal and background events
102  signal->SetBranchAddress("var1", vars+0);
103  signal->SetBranchAddress("var2", vars+1);
104  signal->SetBranchAddress("var3", vars+2);
105  signal->SetBranchAddress("var4", vars+3);
106 
107  background0->SetBranchAddress("var1", vars+0);
108  background0->SetBranchAddress("var2", vars+1);
109  background0->SetBranchAddress("var3", vars+2);
110  background0->SetBranchAddress("var4", vars+3);
111 
112  background1->SetBranchAddress("var1", vars+0);
113  background1->SetBranchAddress("var2", vars+1);
114  background1->SetBranchAddress("var3", vars+2);
115  background1->SetBranchAddress("var4", vars+3);
116 
117  background2->SetBranchAddress("var1", vars+0);
118  background2->SetBranchAddress("var2", vars+1);
119  background2->SetBranchAddress("var3", vars+2);
120  background2->SetBranchAddress("var4", vars+3);
121 
122  Float_t meanMvaSignal = 0;
123  Float_t meanMvaBackground0 = 0;
124  Float_t meanMvaBackground1 = 0;
125  Float_t meanMvaBackground2 = 0;
126  for(UInt_t i=0; i<numEvents; i++){
127  signal->GetEntry(i);
128  meanMvaSignal += reader->EvaluateMulticlass("PyKeras")[0];
129  background0->GetEntry(i);
130  meanMvaBackground0 += reader->EvaluateMulticlass("PyKeras")[1];
131  background1->GetEntry(i);
132  meanMvaBackground1 += reader->EvaluateMulticlass("PyKeras")[2];
133  background2->GetEntry(i);
134  meanMvaBackground2 += reader->EvaluateMulticlass("PyKeras")[3];
135  }
136  meanMvaSignal = meanMvaSignal/float(numEvents);
137  meanMvaBackground0 = meanMvaBackground0/float(numEvents);
138  meanMvaBackground1 = meanMvaBackground1/float(numEvents);
139  meanMvaBackground2 = meanMvaBackground2/float(numEvents);
140 
141  // Check whether the response is obviously better than guessing
142  std::cout << "Mean MVA response on signal: " << meanMvaSignal << std::endl;
143  if(meanMvaSignal < 0.3){
144  std::cout << "[ERROR] Mean response on signal is " << meanMvaSignal << " (<0.3)" << std::endl;
145  return 1;
146  }
147  std::cout << "Mean MVA response on background 0: " << meanMvaBackground0 << std::endl;
148  if(meanMvaBackground0 < 0.3){
149  std::cout << "[ERROR] Mean response on background 0 is " << meanMvaBackground0 << " (<0.3)" << std::endl;
150  return 1;
151  }
152  std::cout << "Mean MVA response on background 1: " << meanMvaBackground1 << std::endl;
153  if(meanMvaBackground0 < 0.3){
154  std::cout << "[ERROR] Mean response on background 1 is " << meanMvaBackground1 << " (<0.3)" << std::endl;
155  return 1;
156  }
157  std::cout << "Mean MVA response on background 2: " << meanMvaBackground2 << std::endl;
158  if(meanMvaBackground0 < 0.3){
159  std::cout << "[ERROR] Mean response on background 2 is " << meanMvaBackground2 << " (<0.3)" << std::endl;
160  return 1;
161  }
162 
163  return 0;
164 }
165 
166 int main(){
167  int err = testPyKerasMulticlass();
168  return err;
169 }
virtual Bool_t AccessPathName(const char *path, EAccessMode mode=kFileExists)
Returns FALSE if the file can be accessed with the specified access mode (note the inverted convention: TRUE means the path is not accessible).
Definition: TSystem.cxx:1272
MethodBase * BookMethod(DataLoader *loader, TString theMethodName, TString methodTitle, TString theOption="")
Book a classifier or regression method.
Definition: Factory.cxx:343
float Float_t
Definition: RtypesCore.h:53
void AddVariable(const TString &expression, Float_t *)
Add a float variable or expression to the reader.
Definition: Reader.cxx:308
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format...
Definition: TFile.h:46
virtual TObject * Get(const char *namecycle)
Return pointer to object identified by namecycle.
#define gROOT
Definition: TROOT.h:375
virtual Int_t GetEntry(Long64_t entry=0, Int_t getall=0)
Read all branches of entry and return total number of bytes read.
Definition: TTree.cxx:5321
Basic string class.
Definition: TString.h:129
void TrainAllMethods()
Iterates through all booked methods and calls training.
Definition: Factory.cxx:1017
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
Add a user-defined discriminating variable to the data set info.
Definition: DataLoader.cxx:491
static void PyInitialize()
Initialize Python interpreter.
static TFile * Open(const char *name, Option_t *option="", const char *ftitle="", Int_t compress=1, Int_t netopt=0)
Create / open a file.
Definition: TFile.cxx:3909
const char * Data() const
Definition: TString.h:344
static TString Format(const char *fmt,...)
Static method that formats a string using a printf-style format descriptor and returns a TString...
Definition: TString.cxx:2345
virtual Int_t SetBranchAddress(const char *bname, void *add, TBranch **ptr=0)
Change branch address, dealing with clone trees properly.
Definition: TTree.cxx:7871
IMethod * BookMVA(const TString &methodTag, const TString &weightfile)
read method name from weight file
Definition: Reader.cxx:377
int testPyKerasMulticlass()
R__EXTERN TSystem * gSystem
Definition: TSystem.h:539
TString pythonSrc
unsigned int UInt_t
Definition: RtypesCore.h:42
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition: TSystem.cxx:660
This is the main MVA steering class.
Definition: Factory.h:81
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
Definition: DataLoader.cxx:357
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
Prepare the training and test trees -> same cuts for signal and background.
Definition: DataLoader.cxx:629
The Reader class serves to use the MVAs in a specific analysis context.
Definition: Reader.h:63
A TTree object has a header with a name and a title.
Definition: TTree.h:78
int main(int argc, char **argv)
const std::vector< Float_t > & EvaluateMulticlass(const TString &methodTag, Double_t aux=0)
evaluates MVA for given set of input variables
Definition: Reader.cxx:647