88 fDataSetManager (
NULL ),
90 fTransformations ( "
I" ),
92 fDataAssignType ( kAssignEvents ),
94 fMakeFoldDataSet ( kFALSE )
98 fLogger->SetSource(
"DataLoader");
107 std::vector<TMVA::VariableTransformBase*>::iterator trfIt =
fDefaultTrfs.begin();
108 for (;trfIt !=
fDefaultTrfs.end(); trfIt++)
delete (*trfIt);
128 return fDataSetManager->AddDataSetInfo(dsi);
135 DataSetInfo* dsi = fDataSetManager->GetDataSetInfo(dsiName);
137 if (dsi!=0)
return *dsi;
139 return fDataSetManager->AddDataSetInfo(*(
new DataSetInfo(dsiName)));
146 return DefaultDataSetInfo();
157 if (trafoDefinition.
Contains(
"(")) {
161 Ssiz_t parLen = trafoDefinition.
Index(
")", parStart )-parStart+1;
163 trName = trafoDefinition(0,parStart);
164 trOptions = trafoDefinition(parStart,parLen);
165 trOptions.
Remove(parLen-1,1);
169 trName = trafoDefinition;
173 if (trName ==
"VT") {
178 Log() << kFATAL <<
" VT transformation must be passed a floating threshold value" <<
Endl;
183 threshold = trOptions.
Atof();
186 return transformedLoader;
189 Log() << kFATAL <<
"Incorrect transformation string provided, please check" <<
Endl;
191 Log() << kINFO <<
"No transformation applied, returning original loader" <<
Endl;
205 assignTree->
Branch(
"type", &fATreeType,
"ATreeType/I" );
206 assignTree->
Branch(
"weight", &fATreeWeight,
"ATreeWeight/F" );
208 std::vector<VariableInfo>& vars = DefaultDataSetInfo().GetVariableInfos();
209 std::vector<VariableInfo>& tgts = DefaultDataSetInfo().GetTargetInfos();
210 std::vector<VariableInfo>& spec = DefaultDataSetInfo().GetSpectatorInfos();
212 if (fATreeEvent.size()==0) fATreeEvent.resize(vars.size()+tgts.size()+spec.size());
214 for (
UInt_t ivar=0; ivar<vars.size(); ivar++) {
215 TString vname = vars[ivar].GetExpression();
216 assignTree->
Branch( vname, &fATreeEvent[ivar], vname +
"/F" );
219 for (
UInt_t itgt=0; itgt<tgts.size(); itgt++) {
220 TString vname = tgts[itgt].GetExpression();
221 assignTree->
Branch( vname, &fATreeEvent[vars.size()+itgt], vname +
"/F" );
224 for (
UInt_t ispc=0; ispc<spec.size(); ispc++) {
225 TString vname = spec[ispc].GetExpression();
226 assignTree->
Branch( vname, &fATreeEvent[vars.size()+tgts.size()+ispc], vname +
"/F" );
284 const std::vector<Double_t>& event,
Double_t weight )
286 ClassInfo* theClass = DefaultDataSetInfo().AddClass(className);
295 if (clIndex>=fTrainAssignTree.size()) {
296 fTrainAssignTree.resize(clIndex+1, 0);
297 fTestAssignTree.resize(clIndex+1, 0);
300 if (fTrainAssignTree[clIndex]==0) {
301 fTrainAssignTree[clIndex] = CreateEventAssignTrees(
Form(
"TrainAssignTree_%s", className.
Data()) );
302 fTestAssignTree[clIndex] = CreateEventAssignTrees(
Form(
"TestAssignTree_%s", className.
Data()) );
305 fATreeType = clIndex;
306 fATreeWeight = weight;
307 for (
UInt_t ivar=0; ivar<
event.size(); ivar++) fATreeEvent[ivar] = event[ivar];
310 else fTestAssignTree[clIndex]->Fill();
319 return fTrainAssignTree[clIndex]!=0;
327 UInt_t size = fTrainAssignTree.size();
328 for(
UInt_t i=0; i<size; i++) {
329 if(!UserAssignEvents(i))
continue;
330 const TString& className = DefaultDataSetInfo().GetClassInfo(i)->GetName();
331 SetWeightExpression(
"weight", className );
349 Log() << kFATAL <<
"<AddTree> cannot interpret tree type: \"" << treetype
350 <<
"\" should be \"Training\" or \"Test\" or \"Training and Testing\"" <<
Endl;
352 AddTree( tree, className, weight, cut, tt );
361 Log() << kFATAL <<
"Tree does not exist (empty pointer)." <<
Endl;
363 DefaultDataSetInfo().AddClass( className );
369 Log() << kINFO<<
"Add Tree " << tree->
GetName() <<
" of type " << className
371 DataInput().AddTree( tree, className, weight, cut, tt );
379 AddTree( signal,
"Signal", weight,
TCut(
""), treetype );
388 TTree* signalTree =
new TTree(
"TreeS",
"Tree (S)" );
391 Log() << kINFO <<
"Create TTree objects from ASCII input files ... \n- Signal file : \""
395 AddTree( signalTree,
"Signal", weight,
TCut(
""), treetype );
402 AddTree( signal,
"Signal", weight,
TCut(
""), treetype );
410 AddTree( signal,
"Background", weight,
TCut(
""), treetype );
419 TTree* bkgTree =
new TTree(
"TreeB",
"Tree (B)" );
422 Log() << kINFO <<
"Create TTree objects from ASCII input files ... \n- Background file : \""
426 AddTree( bkgTree,
"Background", weight,
TCut(
""), treetype );
433 AddTree( signal,
"Background", weight,
TCut(
""), treetype );
440 AddTree( tree,
"Signal", weight );
447 AddTree( tree,
"Background", weight );
473 DataInput().AddTree( datFileS,
"Signal", signalWeight );
474 DataInput().AddTree( datFileB,
"Background", backgroundWeight );
494 DefaultDataSetInfo().AddVariable( expression, title, unit, min, max, type );
503 DefaultDataSetInfo().AddVariable( expression,
"",
"", min, max, type );
515 DefaultDataSetInfo().AddTarget( expression, title, unit, min, max );
524 DefaultDataSetInfo().AddSpectator( expression, title, unit, min, max );
532 return AddDataSet( fName );
540 for (std::vector<TString>::iterator it=theVariables->begin();
541 it!=theVariables->end(); it++) AddVariable(*it);
548 DefaultDataSetInfo().SetWeightExpression(variable,
"Signal");
555 DefaultDataSetInfo().SetWeightExpression(variable,
"Background");
564 SetSignalWeightExpression(variable);
565 SetBackgroundWeightExpression(variable);
567 else DefaultDataSetInfo().SetWeightExpression( variable, className );
573 SetCut(
TCut(cut), className );
580 DefaultDataSetInfo().SetCut( cut, className );
587 AddCut(
TCut(cut), className );
593 DefaultDataSetInfo().AddCut( cut, className );
603 SetInputTreesFromEventAssignTrees();
607 DefaultDataSetInfo().SetSplitOptions(
Form(
"nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:%s",
608 NsigTrain, NbkgTrain, NsigTest, NbkgTest, otherOpt.
Data()) );
617 SetInputTreesFromEventAssignTrees();
621 DefaultDataSetInfo().SetSplitOptions(
Form(
"nTrain_Signal=%i:nTrain_Background=%i:nTest_Signal=%i:nTest_Background=%i:SplitMode=Random:EqualTrainSample:!V",
622 Ntrain, Ntrain, Ntest, Ntest) );
631 SetInputTreesFromEventAssignTrees();
633 DefaultDataSetInfo().PrintClasses();
635 DefaultDataSetInfo().SetSplitOptions( opt );
644 SetInputTreesFromEventAssignTrees();
647 AddCut( sigcut,
"Signal" );
648 AddCut( bkgcut,
"Background" );
650 DefaultDataSetInfo().SetSplitOptions( splitOpt );
661 if(fMakeFoldDataSet){
662 Log() <<
kInfo <<
"Splitting in k-folds has been already done" <<
Endl;
666 fMakeFoldDataSet =
kTRUE;
669 const std::vector<Event*> TrainingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(
Types::kTraining);
670 const std::vector<Event*> TestingData = DefaultDataSetInfo().GetDataSet()->GetEventCollection(
Types::kTesting);
672 std::vector<Event*> TrainSigData;
673 std::vector<Event*> TrainBkgData;
674 std::vector<Event*> TestSigData;
675 std::vector<Event*> TestBkgData;
678 for(
UInt_t i=0; i<TrainingData.size(); ++i){
679 if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->
GetName(),
"Signal", 6) == 0){ TrainSigData.push_back(TrainingData.at(i)); }
680 else if( strncmp( DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->
GetName(),
"Background", 10) == 0){ TrainBkgData.push_back(TrainingData.at(i)); }
682 Log() << kFATAL <<
"DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TrainingData.at(i)->GetClass() )->
GetName() <<
" is not a recognised class" <<
Endl;
686 for(
UInt_t i=0; i<TestingData.size(); ++i){
687 if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->
GetName(),
"Signal", 6) == 0){ TestSigData.push_back(TestingData.at(i)); }
688 else if( strncmp( DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->
GetName(),
"Background", 10) == 0){ TestBkgData.push_back(TestingData.at(i)); }
690 Log() << kFATAL <<
"DataSets should only contain Signal and Background classes for classification, " << DefaultDataSetInfo().GetClassInfo( TestingData.at(i)->GetClass() )->
GetName() <<
" is not a recognised class" <<
Endl;
697 std::vector<std::vector<Event*>> tempSigEvents = SplitSets(TrainSigData,0,2);
698 std::vector<std::vector<Event*>> tempBkgEvents = SplitSets(TrainBkgData,0,2);
699 fTrainSigEvents = SplitSets(tempSigEvents.at(0),0,numberFolds);
700 fTrainBkgEvents = SplitSets(tempBkgEvents.at(0),0,numberFolds);
701 fValidSigEvents = SplitSets(tempSigEvents.at(1),0,numberFolds);
702 fValidBkgEvents = SplitSets(tempBkgEvents.at(1),0,numberFolds);
705 fTrainSigEvents = SplitSets(TrainSigData,0,numberFolds);
706 fTrainBkgEvents = SplitSets(TrainBkgData,0,numberFolds);
709 fTestSigEvents = SplitSets(TestSigData,0,numberFolds);
710 fTestBkgEvents = SplitSets(TestBkgData,0,numberFolds);
718 UInt_t numFolds = fTrainSigEvents.size();
720 std::vector<Event*>* tempTrain =
new std::vector<Event*>;
721 std::vector<Event*>* tempTest =
new std::vector<Event*>;
727 for(
UInt_t i=0; i<numFolds; ++i){
730 nTrain += fTrainSigEvents.at(i).size();
731 nTrain += fTrainBkgEvents.at(i).size();
734 nTest += fTrainSigEvents.at(i).size();
735 nTest += fTrainSigEvents.at(i).size();
740 nTrain += fValidSigEvents.at(i).size();
741 nTrain += fValidBkgEvents.at(i).size();
744 nTest += fValidSigEvents.at(i).size();
745 nTest += fValidSigEvents.at(i).size();
750 nTrain += fTestSigEvents.at(i).size();
751 nTrain += fTestBkgEvents.at(i).size();
754 nTest += fTestSigEvents.at(i).size();
755 nTest += fTestSigEvents.at(i).size();
761 tempTrain->reserve(nTrain);
762 tempTest->reserve(nTest);
765 for(
UInt_t j=0; j<numFolds; ++j){
768 tempTrain->insert(tempTrain->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
769 tempTrain->insert(tempTrain->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
772 tempTest->insert(tempTest->end(), fTrainSigEvents.at(j).begin(), fTrainSigEvents.at(j).end());
773 tempTest->insert(tempTest->end(), fTrainBkgEvents.at(j).begin(), fTrainBkgEvents.at(j).end());
778 tempTrain->insert(tempTrain->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
779 tempTrain->insert(tempTrain->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
782 tempTest->insert(tempTest->end(), fValidSigEvents.at(j).begin(), fValidSigEvents.at(j).end());
783 tempTest->insert(tempTest->end(), fValidBkgEvents.at(j).begin(), fValidBkgEvents.at(j).end());
788 tempTrain->insert(tempTrain->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
789 tempTrain->insert(tempTrain->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
792 tempTest->insert(tempTest->end(), fTestSigEvents.at(j).begin(), fTestSigEvents.at(j).end());
793 tempTest->insert(tempTest->end(), fTestBkgEvents.at(j).begin(), fTestBkgEvents.at(j).end());
799 DefaultDataSetInfo().GetDataSet()->SetEventCollection(tempTrain,
Types::kTraining,
false);
800 DefaultDataSetInfo().GetDataSet()->SetEventCollection(tempTest,
Types::kTesting,
false);
813 std::vector<std::vector<Event*>> tempSets;
814 tempSets.resize(numFolds);
822 if(inSet == foldSize*numFolds){
828 if(tempSets.at(s).size()<foldSize){
829 tempSets.at(s).push_back(oldSet.at(i));
858 des->
AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
863 des->
AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
872 const TMatrixD *
m = DefaultDataSetInfo().CorrelationMatrix(className);
873 return DefaultDataSetInfo().CreateCorrelationMatrixHist(m,
874 "CorrelationMatrix"+className,
"Correlation Matrix ("+className+
")");
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add a background tree with the given weight and tree type
std::string GetName(const std::string &scope_name)
DataSetManager * fDataSetManager
Random number generator class based on M. Matsumoto and T. Nishimura's Mersenne Twister generator.
MsgLogger & Endl(MsgLogger &ml)
void AddTrainingEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add a training event for the class given by className
std::vector< TMVA::VariableTransformBase * > fDefaultTrfs
DataSetInfo & GetDataSetInfo()
Double_t Atof() const
Return floating-point value contained in string.
TTree * CreateEventAssignTrees(const TString &name)
create the data assignment tree (for event-wise data assignment by user)
DataSetInfo & DefaultDataSetInfo()
default creation
DataLoader * VarTransform(TString trafoDefinition)
Transforms the variables and returns a new DataLoader with the transformed variables.
void ToLower()
Change string to lower-case.
void MakeKFoldDataSet(UInt_t numberFolds, bool validationSet=false)
Function required to split the training and testing datasets into a number of folds.
void DataLoaderCopy(TMVA::DataLoader *des, TMVA::DataLoader *src)
void SetBackgroundTree(TTree *background, Double_t weight=1.0)
DataInputHandler * fDataInputHandler
void AddBackgroundTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add background test event
TH2 * GetCorrelationMatrix(const TString &className)
returns the correlation matrix of datasets
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
const char * Data() const
Class that contains all the information of a class.
void AddTestEvent(const TString &className, const std::vector< Double_t > &event, Double_t weight)
add a test event for the class given by className
void SetInputTrees(const TString &signalFileName, const TString &backgroundFileName, Double_t signalWeight=1.0, Double_t backgroundWeight=1.0)
virtual UInt_t Integer(UInt_t imax)
Returns a random integer on [ 0, imax-1 ].
void SetTree(TTree *tree, const TString &className, Double_t weight)
set the tree for the class given by className
void PrepareFoldDataSet(UInt_t foldNumber, Types::ETreeType tt)
Function for assigning the correct folds to the testing or training set.
Class that contains all the data information.
void SetInputVariables(std::vector< TString > *theVariables)
fill input variables in data set
DataSetInfo & AddDataSet(DataSetInfo &)
void AddCut(const TString &cut, const TString &className="")
A specialized string object used for TTree selections.
void SetInputTreesFromEventAssignTrees()
assign event-wise local trees to data set
DataInputHandler & DataInput()
Service class for 2-Dim histogram classes.
char * Form(const char *fmt,...)
Bool_t UserAssignEvents(UInt_t clIndex)
virtual const char * GetName() const
Returns name of object.
DataLoader * MakeCopy(TString name)
Copy method used in VI and CV.
TString & Remove(Ssiz_t pos)
void AddTree(TTree *tree, const TString &className, Double_t weight=1.0, const TCut &cut="", Types::ETreeType tt=Types::kMaxTreeType)
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
virtual void SetDirectory(TDirectory *dir)
Change the tree's directory.
void AddEvent(const TString &className, Types::ETreeType tt, const std::vector< Double_t > &event, Double_t weight)
add event vector event : the order of values is: variables + targets + spectators ...
Class that contains all the data information.
void SetBackgroundWeightExpression(const TString &variable)
unsigned long long ULong64_t
Bool_t IsFloat() const
Returns kTRUE if string contains a floating point or integer number.
static void DestroyInstance()
static function: destroy TMVA instance
void AddTarget(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts target in data set info
void SetWeightExpression(const TString &variable, const TString &className="")
void AddBackgroundTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add background training event
void SetSignalWeightExpression(const TString &variable)
virtual Long64_t ReadFile(const char *filename, const char *branchDescriptor="", char delimiter= ' ')
Create or simply read branches from filename.
virtual Int_t Branch(TCollection *list, Int_t bufsize=32000, Int_t splitlevel=99, const char *name="")
Create one branch for each element in the collection.
void AddSignalTestEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal testing event
void AddSignalTrainingEvent(const std::vector< Double_t > &event, Double_t weight=1.0)
add signal training event
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
void SetSignalTree(TTree *signal, Double_t weight=1.0)
virtual Long64_t GetEntries() const
A TTree object has a header with a name and a title.
std::vector< std::vector< TMVA::Event * > > SplitSets(std::vector< TMVA::Event * > &oldSet, int seedNum, int numFolds)
Splits the input vector into equally sized randomly sampled folds.
Ssiz_t Index(const char *pat, Ssiz_t i=0, ECaseCompare cmp=kExact) const
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add a signal tree with the given weight and tree type
void SetCut(const TString &cut, const TString &className="")
void AddSpectator(const TString &expression, const TString &title="", const TString &unit="", Double_t min=0, Double_t max=0)
user inserts spectator in data set info