Program Listing for File LoaderKeypointData.h
↰ Return to documentation for file (larflow/KeyPoints/LoaderKeypointData.h)
#ifndef __LOADER_KEYPOINT_DATA_H__
#define __LOADER_KEYPOINT_DATA_H__
#include <Python.h>
#include "bytesobject.h"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/ndarrayobject.h>
#include <string>
#include <vector>
#include "TChain.h"
#include "larflow/PrepFlowMatchData/PrepMatchTriplets.h"
namespace larflow {
namespace keypoints {
/**
 * @brief Loader holding ROOT TChains (triplet, keypoint, ssnet) and producing
 *        Python/numpy objects from the loaded entries.
 *
 * Input files are collected in input_files, either through the vector
 * constructor or through add_input_file(). The TChain members and the
 * branch-buffer pointers below are presumably set up by load_tree() and
 * refreshed by load_entry(); confirm this in the implementation file.
 *
 * Every non-static data member has an in-class default. Any constructor that
 * does not set a member explicitly therefore starts from a defined state.
 * This includes the out-of-line vector constructor. Pointers start as nullptr
 * and _exclude_neg_examples starts as true. The defaults matter for ROOT
 * branch buffers: TTree::SetBranchAddress expects an object pointer that is
 * either nullptr (ROOT allocates) or points at a valid object, never garbage.
 *
 * NOTE(review): the class holds raw TChain* handles and declares a virtual
 * destructor, but copy/move operations are still implicitly generated.
 * Copying an instance would share those pointers. Consider deleting the copy
 * operations once it is confirmed that no caller copies this type.
 */
class LoaderKeypointData {

 public:

  /// Construct with no input files; files can then be added with add_input_file().
  LoaderKeypointData() = default;

  /// Construct from a list of input files (defined out of line).
  LoaderKeypointData( std::vector<std::string>& input_v );

  virtual ~LoaderKeypointData();

  /// Input files registered with this loader.
  std::vector<std::string> input_files;

  /// Append one file name to input_files.
  void add_input_file( std::string input ) { input_files.push_back(input); }

  TChain* ttriplet  = nullptr;  ///< chain for triplet data (presumably feeds triplet_v)
  TChain* tkeypoint = nullptr;  ///< chain for keypoint data (presumably feeds kplabel_v / kpshift_v)
  TChain* tssnet    = nullptr;  ///< chain for ssnet data (presumably feeds ssnet_label_v / ssnet_weight_v)

  // Per-entry data buffers. They are presumably bound as branch addresses in
  // load_tree(); confirm in the implementation.
  std::vector<larflow::prep::PrepMatchTriplets>* triplet_v = nullptr;
  /// Three label tables; the meaning of the index is not visible here (TODO confirm).
  std::vector< std::vector<float> >* kplabel_v[3] = { nullptr, nullptr, nullptr };
  std::vector< std::vector<float> >* kpshift_v = nullptr;
  std::vector< int >*   ssnet_label_v  = nullptr;
  std::vector< float >* ssnet_weight_v = nullptr;

  /// Set the _exclude_neg_examples flag (defaults to true).
  void exclude_false_triplets( bool exclude ) { _exclude_neg_examples = exclude; }

  /// Set up the chains/branches (defined out of line).
  void load_tree();

  /// Load entry @p entry (defined out of line; return value semantics are
  /// presumably those of TChain::GetEntry — confirm).
  unsigned long load_entry( int entry );

  /// Number of available entries (defined out of line).
  unsigned long GetEntries();

  /**
   * @brief Build a Python object from the currently loaded entry.
   * @param num_max_samples upper bound on the number of samples (presumably)
   * @param nfilled         out-parameter, presumably the number actually filled
   * @param withtruth       presumably toggles inclusion of truth labels
   * @return new Python object; ownership conventions are defined in the .cxx
   */
  PyObject* sample_data( const int& num_max_samples,
                         int& nfilled,
                         bool withtruth );

 protected:

  // Helpers that fill numpy arrays through PyArrayObject*& out-parameters.
  // Their int return values are defined in the implementation file.

  int make_ssnet_arrays( const int& num_max_samples,
                         int& nfilled,
                         bool withtruth,
                         std::vector<int>& pos_match_index,
                         PyArrayObject* match_array,
                         PyArrayObject*& ssnet_label,
                         PyArrayObject*& ssnet_top_weight,
                         PyArrayObject*& ssnet_class_weight );

  int make_kplabel_arrays( const int& num_max_samples,
                           int& nfilled,
                           bool withtruth,
                           std::vector<int>& pos_match_index,
                           PyArrayObject* match_array,
                           PyArrayObject*& kplabel_label,
                           PyArrayObject*& kplabel_weight );

  int make_kpshift_arrays( const int& num_max_samples,
                           int& nfilled,
                           bool withtruth,
                           PyArrayObject* match_array,
                           PyArrayObject*& kpshift_label );

  /// Presumably records whether the numpy C API has been initialized; confirm in the .cxx.
  static bool _setup_numpy;

  /// Set by exclude_false_triplets(); true unless a constructor or caller changes it.
  bool _exclude_neg_examples = true;

};
}
}
#endif