| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943 |
- #include <stdio.h>
- #include "squirrel.h"
- #include <string.h>
- #include <inttypes.h>
- #include <math.h>
- #include <stdlib.h>
- #include <sys/time.h>
- //#include <pthread.h>
- SQ_OPT_STRING_STRLEN();
- #include "floatfann.h"
- struct SQFannTrainData
- {
- fann_train_data *data;
- };
- static const SQChar sq_fann_training_data_TAG[] = _SC("SQFannTrainData");
- #define SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, at, vname) \
- SQ_GET_INSTANCE_VAR(v, at, SQFannTrainData, vname, sq_fann_training_data_TAG)
- #define SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, at) \
- SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, at, self)
- static SQRESULT sq_fann_training_data_release_hook(SQUserPointer p, SQInteger size, void */*ep*/) {
- SQFannTrainData *self = (SQFannTrainData *)p;
- if(self->data) fann_destroy_train(self->data);
- return 0;
- }
- /*
- ** Creates a new fann_train_data.
- */
- static SQRESULT sq_fann_training_data_constructor (HSQUIRRELVM v) {
- SQ_FUNC_VARS_NO_TOP(v);
- fann_train_data *data = NULL;
- SQObjectType otype = sq_gettype(v, 2);
- if(otype == OT_STRING)
- {
- SQ_GET_STRING(v, 2, data_fname);
- data = fann_read_train_from_file(data_fname);
- }
- else if(otype == OT_INTEGER)
- {
- SQ_GET_INTEGER(v, 2, num_data);
- SQ_GET_INTEGER(v, 3, num_input);
- SQ_GET_INTEGER(v, 4, num_output);
- if(num_data <= 0 || num_input <= 0 || num_output <= 0)
- return sq_throwerror(v, _SC("expect only dimensions > 0"));
- data = fann_create_train(num_data, num_input, num_output);
- }
- if(!data) return sq_throwerror(v, _SC("could not create train data"));
- SQFannTrainData *self = (SQFannTrainData*)sq_malloc(sizeof(*self));
- self->data = data;
- sq_setinstanceup(v, 1, self);
- sq_setreleasehook(v, 1, sq_fann_training_data_release_hook);
- return 1;
- }
- static SQRESULT sq_fann_training_data_shuffle(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- if(self->data) fann_shuffle_train_data(self->data);
- else return sq_throwerror(v, _SC("train data not initialized"));
- return 0;
- }
- static SQRESULT sq_fann_training_data_get_errstr(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- sq_pushstring(v, fann_get_errstr((struct fann_error *)self->data), -1);
- return 1;
- }
- static SQRESULT sq_fann_training_data_get_errno(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- sq_pushinteger(v, fann_get_errno((struct fann_error *)self->data));
- return 1;
- }
- static SQRESULT sq_fann_training_data_reset_errno(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- fann_reset_errno((struct fann_error *)self->data);
- return 0;
- }
- static SQRESULT sq_fann_training_data_save(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- SQ_GET_STRING(v, 2, fname);
- sq_pushinteger(v, fann_save_train(self->data, fname));
- return 1;
- }
- #define SQ_FANN_TRAINING_DATA_GET_INT(field)\
- static SQRESULT sq_fann_training_data_##field(HSQUIRRELVM v)\
- {\
- SQ_FUNC_VARS_NO_TOP(v);\
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);\
- sq_pushinteger(v, self->data->field);\
- return 1;\
- }
- SQ_FANN_TRAINING_DATA_GET_INT(num_data);
- SQ_FANN_TRAINING_DATA_GET_INT(num_input);
- SQ_FANN_TRAINING_DATA_GET_INT(num_output);
- static SQRESULT sq_fann_training_data_set_input_at(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- SQ_GET_INTEGER(v, 2, row);
- if(row < 0 || row >= self->data->num_data) return sq_throwerror(v, _SC("index out fo bounds"));
- SQInteger cols = sq_getsize(v, 3);
- if(cols != self->data->num_input) return sq_throwerror(v, _SC("cols mismatch"));
- //fill input
- fann_type **input = self->data->input;
- for(SQInteger i=0; i < cols; ++i)
- {
- sq_arrayget(v, 3, i);
- SQ_GET_FLOAT(v, -1, value);
- input[row][i] = value;
- sq_poptop(v);
- }
- return 0;
- }
- static SQRESULT sq_fann_training_data_set_output_at(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- SQ_GET_INTEGER(v, 2, row);
- if(row < 0 || row >= self->data->num_data) return sq_throwerror(v, _SC("index out fo bounds"));
- SQInteger cols = sq_getsize(v, 3);
- if(cols != self->data->num_output) return sq_throwerror(v, _SC("cols mismatch"));
- //fill input
- fann_type **output = self->data->output;
- for(SQInteger i=0; i < cols; ++i)
- {
- sq_arrayget(v, 3, i);
- SQ_GET_FLOAT(v, -1, value);
- output[row][i] = value;
- sq_poptop(v);
- }
- return 0;
- }
- #define SCALE_ALL 0
- #define SCALE_INPUT 1
- #define SCALE_OUTPUT 2
- static SQRESULT sq_fann_training_data_scale_iot(HSQUIRRELVM v, int iot)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE(v, 1);
- SQ_GET_FLOAT(v, 2, new_min);
- SQ_GET_FLOAT(v, 3, new_max);
- switch(iot)
- {
- case SCALE_ALL:
- fann_scale_train_data(self->data, new_min, new_max);
- break;
- case SCALE_INPUT:
- fann_scale_input_train_data(self->data, new_min, new_max);
- break;
- case SCALE_OUTPUT:
- fann_scale_output_train_data(self->data, new_min, new_max);
- break;
- }
- return 0;
- }
- static SQRESULT sq_fann_training_data_scale(HSQUIRRELVM v)
- {
- return sq_fann_training_data_scale_iot(v, SCALE_ALL);
- }
- static SQRESULT sq_fann_training_data_scale_input(HSQUIRRELVM v)
- {
- return sq_fann_training_data_scale_iot(v, SCALE_INPUT);
- }
- static SQRESULT sq_fann_training_data_scale_output(HSQUIRRELVM v)
- {
- return sq_fann_training_data_scale_iot(v, SCALE_OUTPUT);
- }
- #define _DECL_FUNC(name,nparams,tycheck) {_SC(#name),sq_fann_training_data_##name,nparams,tycheck}
- static SQRegFunction sq_fann_training_data_methods[] =
- {
- _DECL_FUNC(constructor, -2,_SC("x s|i ii")),
- _DECL_FUNC(shuffle, 1,_SC("x")),
- _DECL_FUNC(get_errstr, 1,_SC("x")),
- _DECL_FUNC(get_errno, 1,_SC("x")),
- _DECL_FUNC(reset_errno, 1,_SC("x")),
- _DECL_FUNC(num_data, 1,_SC("x")),
- _DECL_FUNC(num_input, 1,_SC("x")),
- _DECL_FUNC(num_output, 1,_SC("x")),
- _DECL_FUNC(save, 2,_SC("xs")),
- _DECL_FUNC(set_input_at, 3,_SC("xia")),
- _DECL_FUNC(set_output_at, 3,_SC("xia")),
- _DECL_FUNC(scale, 3,_SC("xnn")),
- _DECL_FUNC(scale_input, 3,_SC("xnn")),
- _DECL_FUNC(scale_output, 3,_SC("xnn")),
- {0,0}
- };
- #undef _DECL_FUNC
- struct SQFann
- {
- fann *ann;
- };
- struct SQFannCallback
- {
- HSQUIRRELVM v;
- HSQOBJECT cb;
- HSQOBJECT udata;
- };
- static const SQChar sq_fann_TAG[] = _SC("SQFann");
- #define SQ_GET_FANN_INSTANCE_NAME_AT(v, at, vname) SQ_GET_INSTANCE_VAR(v, at, SQFann, vname, sq_fann_TAG)
- #define SQ_GET_FANN_INSTANCE(v, at) SQ_GET_FANN_INSTANCE_NAME_AT(v, at, self)
- static void release_sq_fann_callback(SQFannCallback *cb)
- {
- sq_release(cb->v, &cb->cb);
- sq_release(cb->v, &cb->udata);
- sq_free(cb, sizeof(*cb));
- }
- static SQRESULT sq_fann_release_hook(SQUserPointer p, SQInteger size, void */*ep*/) {
- SQFann *self = (SQFann *)p;
- if(self->ann)
- {
- SQFannCallback *cb = (SQFannCallback*)fann_get_user_data(self->ann);
- if(cb) release_sq_fann_callback(cb);
- fann_destroy(self->ann);
- }
- return 0;
- }
- /*
- ** Creates a new fann.
- */
- static SQRESULT sq_fann_constructor (HSQUIRRELVM v) {
- SQ_FUNC_VARS(v);
- fann *ann;
- SQObjectType otype = sq_gettype(v, 2);
- if(otype == OT_STRING)
- {
- SQ_GET_STRING(v, 2, net_fname);
- ann = fann_create_from_file(net_fname);
- }
- else
- {
- int create_type = FANN_NETTYPE_LAYER;
- if(_top_ > 2)
- {
- SQ_GET_INTEGER(v, 3, nt);
- create_type = nt;
- }
- switch(create_type)
- {
- case FANN_NETTYPE_LAYER:
- case FANN_NETTYPE_SHORTCUT:
- break;
- default:
- return sq_throwerror(v, _SC("invalid net type"));
- }
- SQFloat connection_rate;
- SQInteger array_pos = 2;
- if(otype == OT_FLOAT)
- {
- ++array_pos;
- SQ_GET_FLOAT(v, 2, cr);
- connection_rate = cr;
- }
- SQInteger num_layers = sq_getsize(v, array_pos);
- unsigned int *layers = (unsigned int *)
- sq_getscratchpad(v, num_layers*sizeof(unsigned int));
- for(SQInteger i=0; i < num_layers; ++i)
- {
- sq_arrayget(v, array_pos, i);
- SQ_GET_INTEGER(v, -1, value);
- layers[i] = value;
- sq_poptop(v);
- }
- if(array_pos == 2)
- {
- if(create_type == FANN_NETTYPE_LAYER)
- ann = fann_create_standard_array(num_layers, layers);
- else if(create_type == FANN_NETTYPE_SHORTCUT)
- ann = fann_create_shortcut_array(num_layers, layers);
- }
- else ann = fann_create_sparse_array(connection_rate, num_layers, layers);
- }
- if(ann)
- {
- SQFann *self = (SQFann*)sq_malloc(sizeof(*self));
- self->ann = ann;
- sq_setinstanceup(v, 1, self);
- sq_setreleasehook(v, 1, sq_fann_release_hook);
- return 1;
- }
- return sq_throwerror(v, _SC("failed to create SQFann"));
- }
- static SQRESULT sq_fann_copy(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- fann *ann = fann_copy(self->ann);
- if(ann)
- {
- SQFann *new_self = (SQFann*)sq_malloc(sizeof(*self));
- new_self->ann = ann;
- sq_pushstring(v, sq_fann_TAG, -1);
- sq_getonroottable(v);
- sq_createinstance(v, -1);
- sq_setinstanceup(v, 1, new_self);
- sq_setreleasehook(v, 1, sq_fann_release_hook);
- return 1;
- }
- return sq_throwerror(v, _SC("failed to create SQFann"));
- }
- static int sq_fann_callback_c(fann *ann, fann_train_data *train,
- unsigned int max_epochs, unsigned int epochs_between_reports,
- float desired_error, unsigned int epochs)
- {
- SQFannCallback *cb = (SQFannCallback*)fann_get_user_data(ann);
- if(cb)
- {
- /* ensure there is enough space in the stack */
- sq_reservestack(cb->v, 20);
- SQInteger top = sq_gettop(cb->v);
- sq_pushobject(cb->v, cb->cb);
- sq_pushroottable(cb->v);
- sq_pushobject(cb->v, cb->udata);
- sq_pushinteger(cb->v, max_epochs);
- sq_pushinteger(cb->v, epochs_between_reports);
- sq_pushfloat(cb->v, desired_error);
- sq_pushinteger(cb->v, epochs);
- /* call squilu function */
- SQInteger rc = 0;
- if (sq_call(cb->v, 6, SQTrue, SQFalse) == SQ_OK)
- sq_getinteger(cb->v, -1, &rc);
- sq_settop(cb->v, top);
- return rc;
- }
- return 0;
- }
- static SQRESULT sq_fann_set_callback(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQFannCallback *cb = (SQFannCallback*)fann_get_user_data(self->ann);
- if(cb) release_sq_fann_callback(cb);
- cb = (SQFannCallback*)sq_malloc(sizeof(*cb));
- cb->v = v;
- sq_resetobject(&cb->cb);
- sq_getstackobj(v, 2, &cb->cb);
- sq_addref(v, &cb->cb);
- sq_resetobject(&cb->udata);
- if(_top_ > 2)
- {
- sq_getstackobj(v, 3, &cb->udata);
- sq_addref(v, &cb->udata);
- }
- fann_set_user_data(self->ann, cb);
- fann_set_callback(self->ann, sq_fann_callback_c);
- return 0;
- }
- static SQRESULT sq_fann_learning_rate(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- if(_top_ > 1)
- {
- SQ_GET_FLOAT(v, 2, learn_rate);
- fann_set_learning_rate(self->ann, learn_rate);
- return 0;
- }
- sq_pushfloat(v, fann_get_learning_rate(self->ann));
- return 1;
- }
- static SQRESULT sq_fann_learning_momentum(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- if(_top_ > 1)
- {
- SQ_GET_FLOAT(v, 2, value);
- fann_set_learning_momentum(self->ann, value);
- return 0;
- }
- sq_pushfloat(v, fann_get_learning_momentum(self->ann));
- return 1;
- }
- static SQRESULT sq_fann_training_algorithm(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- if(_top_ > 1)
- {
- SQ_GET_INTEGER(v, 2, value);
- fann_set_training_algorithm(self->ann, (fann_train_enum)value);
- return 0;
- }
- sq_pushinteger(v, fann_get_training_algorithm(self->ann));
- return 1;
- }
- static SQRESULT sq_fann_train_error_function(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- if(_top_ > 1)
- {
- SQ_GET_INTEGER(v, 2, value);
- fann_set_train_error_function(self->ann, (fann_errorfunc_enum)value);
- return 0;
- }
- sq_pushinteger(v, fann_get_train_error_function(self->ann));
- return 1;
- }
- static SQRESULT sq_fann_train_stop_function(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- if(_top_ > 1)
- {
- SQ_GET_INTEGER(v, 2, value);
- fann_set_train_stop_function(self->ann, (fann_stopfunc_enum)value);
- return 0;
- }
- sq_pushinteger(v, fann_get_train_stop_function(self->ann));
- return 1;
- }
- static SQRESULT sq_fann_activation_steepness(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- if(_top_ > 3)
- {
- SQ_GET_FLOAT(v, 2, value);
- SQ_GET_INTEGER(v, 3, layer);
- SQ_GET_INTEGER(v, 4, neuron);
- fann_set_activation_steepness(self->ann, (fann_type)value, layer, neuron);
- return 0;
- }
- SQ_GET_INTEGER(v, 2, layer);
- SQ_GET_INTEGER(v, 3, neuron);
- sq_pushfloat(v, fann_get_activation_steepness(self->ann, layer, neuron));
- return 1;
- }
/*
** Generators for thin SQFann wrappers; every expansion is a complete
** function definition named sq_fann_<fn>.
**
** SQFANN_GET_FLOAT_OR_INT(push_suffix, fn): pushes fann_<fn>(self->ann)
**   with sq_push<push_suffix>() and returns it.
** SQFANN_SET_FLOAT_OR_INT(get_suffix, fn, arg_cast): reads argument 2 with
**   SQ_GET_<get_suffix>, casts it to arg_cast and calls
**   fann_<fn>(self->ann, arg); returns nothing.
** The FLOAT / INTEGER variants just pick the suffix.
*/
#define SQFANN_GET_FLOAT_OR_INT(push_suffix, fn) \
    static SQRESULT sq_fann_##fn(HSQUIRRELVM v) { \
        SQ_FUNC_VARS_NO_TOP(v); \
        SQ_GET_FANN_INSTANCE(v, 1); \
        sq_push##push_suffix(v, fann_##fn(self->ann)); \
        return 1; \
    }
#define SQFANN_GET_FLOAT(fn)   SQFANN_GET_FLOAT_OR_INT(float, fn)
#define SQFANN_GET_INTEGER(fn) SQFANN_GET_FLOAT_OR_INT(integer, fn)

#define SQFANN_SET_FLOAT_OR_INT(get_suffix, fn, arg_cast) \
    static SQRESULT sq_fann_##fn(HSQUIRRELVM v) { \
        SQ_FUNC_VARS_NO_TOP(v); \
        SQ_GET_FANN_INSTANCE(v, 1); \
        SQ_GET_##get_suffix(v, 2, value); \
        fann_##fn(self->ann, (arg_cast)value); \
        return 0; \
    }
#define SQFANN_SET_FLOAT(fn, arg_cast)   SQFANN_SET_FLOAT_OR_INT(FLOAT, fn, arg_cast)
#define SQFANN_SET_INTEGER(fn, arg_cast) SQFANN_SET_FLOAT_OR_INT(INTEGER, fn, arg_cast)
- SQFANN_SET_FLOAT(set_activation_steepness_hidden, fann_type);
- SQFANN_SET_FLOAT(set_activation_steepness_output, fann_type);
- SQFANN_SET_FLOAT(set_quickprop_decay, fann_type);
- SQFANN_SET_FLOAT(set_quickprop_mu, fann_type);
- SQFANN_SET_FLOAT(set_rprop_increase_factor, fann_type);
- SQFANN_SET_FLOAT(set_rprop_decrease_factor, fann_type);
- SQFANN_SET_FLOAT(set_rprop_delta_min, fann_type);
- SQFANN_SET_FLOAT(set_rprop_delta_max, fann_type);
- SQFANN_SET_FLOAT(set_cascade_output_change_fraction, fann_type);
- SQFANN_SET_INTEGER(set_cascade_output_stagnation_epochs, unsigned);
- SQFANN_SET_FLOAT(set_cascade_candidate_change_fraction, fann_type);
- SQFANN_SET_INTEGER(set_cascade_candidate_stagnation_epochs, unsigned);
- SQFANN_SET_FLOAT(set_cascade_weight_multiplier, fann_type);
- SQFANN_SET_FLOAT(set_cascade_candidate_limit, fann_type);
- SQFANN_SET_INTEGER(set_cascade_max_out_epochs, unsigned);
- SQFANN_SET_INTEGER(set_cascade_max_cand_epochs, unsigned);
- SQFANN_SET_INTEGER(set_cascade_num_candidate_groups, unsigned);
- SQFANN_SET_INTEGER(set_activation_function_hidden, fann_activationfunc_enum);
- SQFANN_SET_INTEGER(set_activation_function_output, fann_activationfunc_enum);
- static SQRESULT sq_fann_randomize_weights(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FLOAT(v, 2, min_weight);
- SQ_GET_FLOAT(v, 3, max_weight);
- fann_randomize_weights(self->ann, min_weight, max_weight);
- return 0;
- }
- static SQRESULT sq_fann_reset_MSE(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- fann_reset_MSE(self->ann);
- return 0;
- }
- SQFANN_GET_FLOAT(get_MSE);
- static SQRESULT sq_fann_get_errstr(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- sq_pushstring(v, fann_get_errstr((struct fann_error *)self->ann), -1);
- return 1;
- }
- static SQRESULT sq_fann_get_errno(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- sq_pushinteger(v, fann_get_errno((struct fann_error *)self->ann));
- return 1;
- }
- static SQRESULT sq_fann_reset_errno(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- fann_reset_errno((struct fann_error *)self->ann);
- return 0;
- }
- SQFANN_GET_INTEGER(get_num_input);
- SQFANN_GET_INTEGER(get_num_output);
- SQFANN_GET_INTEGER(get_bit_fail);
- static SQRESULT sq_fann_train_on_data_type(HSQUIRRELVM v, bool isCascade)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, 2, data);
- SQ_GET_INTEGER(v, 3, max_epochs);
- SQ_GET_INTEGER(v, 4, epochs_between_reports);
- SQ_GET_FLOAT(v, 5, desired_error);
- if(isCascade)
- fann_cascadetrain_on_data(self->ann, data->data, max_epochs,
- epochs_between_reports, desired_error);
- else
- fann_train_on_data(self->ann, data->data, max_epochs,
- epochs_between_reports, desired_error);
- return 0;
- }
- static SQRESULT sq_fann_train_on_data(HSQUIRRELVM v)
- {
- return sq_fann_train_on_data_type(v, false);
- }
- static SQRESULT sq_fann_cascadetrain_on_data(HSQUIRRELVM v)
- {
- return sq_fann_train_on_data_type(v, true);
- }
- static SQRESULT sq_fann_train_on_file(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_STRING(v, 2, data_fn);
- SQ_GET_INTEGER(v, 3, max_epochs);
- SQ_GET_INTEGER(v, 4, epochs_between_reports);
- SQ_GET_FLOAT(v, 5, desired_error);
- fann_train_on_file(self->ann, data_fn, max_epochs,
- epochs_between_reports, desired_error);
- return 0;
- }
- static SQRESULT sq_fann_test_data(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, 2, data);
- sq_pushfloat(v, fann_test_data(self->ann, data->data));
- return 1;
- }
- static SQRESULT sq_fann_test(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQInteger isize = sq_getsize(v, 2);
- SQInteger ann_num_input = fann_get_num_input(self->ann);
- if(isize != ann_num_input) return sq_throwerror(v, _SC("wrong number of inputs"));
- SQInteger osize = sq_getsize(v, 3);
- SQInteger ann_num_output = fann_get_num_output(self->ann);
- if(osize != ann_num_output) return sq_throwerror(v, _SC("wrong number of outputs"));
- fann_type *data = (fann_type*)sq_getscratchpad(v, (osize+isize)*sizeof(fann_type));
- fann_type *input = data;
- fann_type *output = data+isize;
- for(SQInteger i=0; i < isize; ++i)
- {
- sq_arrayget(v, 2, i);
- SQ_GET_FLOAT(v, -1, value);
- input[i] = value;
- sq_poptop(v);
- }
- for(SQInteger i=0; i < osize; ++i)
- {
- sq_arrayget(v, 3, i);
- SQ_GET_FLOAT(v, -1, value);
- output[i] = value;
- sq_poptop(v);
- }
- fann_type *calc_output = fann_test(self->ann, input, output);
- sq_newarray(v, ann_num_output);
- for(SQInteger i=0; i < ann_num_output; ++i)
- {
- sq_pushfloat(v, calc_output[i]);
- sq_arrayset(v, -2, i);
- }
- return 1;
- }
- static SQRESULT sq_fann_run(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQInteger isize = sq_getsize(v, 2);
- SQInteger ann_num_input = fann_get_num_input(self->ann);
- if(isize != ann_num_input) return sq_throwerror(v, _SC("wrong number of inputs"));
- SQInteger ann_num_output = fann_get_num_output(self->ann);
- fann_type *input = (fann_type*)sq_getscratchpad(v, isize*sizeof(fann_type));
- fann_type *calc_output;
- for(SQInteger i=0; i < isize; ++i)
- {
- sq_arrayget(v, 2, i);
- SQ_GET_FLOAT(v, -1, value);
- input[i] = value;
- sq_poptop(v);
- }
- calc_output = fann_run(self->ann, input);
- if(!calc_output) return sq_throwerror(v, _SC("error running ann"));
- sq_newarray(v, ann_num_output);
- for(SQInteger i=0; i < ann_num_output; ++i)
- {
- sq_pushfloat(v, calc_output[i]);
- sq_arrayset(v, -2, i);
- }
- return 1;
- }
- static SQRESULT sq_fann_set_scaling_params(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, 2, data);
- SQ_GET_FLOAT(v, 3, new_input_min);
- SQ_GET_FLOAT(v, 4, new_input_max);
- SQ_GET_FLOAT(v, 5, new_output_min);
- SQ_GET_FLOAT(v, 6, new_output_max);
- fann_set_scaling_params(self->ann, data->data,
- new_input_min, new_input_max, new_output_min, new_output_max);
- return 0;
- }
- static SQRESULT sq_fann_set_input_ouput_scaling_params(HSQUIRRELVM v, bool isInput)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, 2, data);
- SQ_GET_FLOAT(v, 3, new_min);
- SQ_GET_FLOAT(v, 4, new_max);
- if(isInput) fann_set_input_scaling_params(self->ann, data->data, new_min, new_max);
- else fann_set_output_scaling_params(self->ann, data->data, new_min, new_max);
- return 0;
- }
- static SQRESULT sq_fann_set_input_scaling_params(HSQUIRRELVM v)
- {
- return sq_fann_set_input_ouput_scaling_params(v, true);
- }
- static SQRESULT sq_fann_set_output_scaling_params(HSQUIRRELVM v)
- {
- return sq_fann_set_input_ouput_scaling_params(v, false);
- }
- static SQRESULT sq_fann_init_weights(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, 2, data);
- fann_init_weights(self->ann, data->data);
- return 0;
- }
- static SQRESULT sq_fann_scale_train(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FANN_TRAINING_DATA_INSTANCE_NAME_AT(v, 2, data);
- fann_scale_train(self->ann, data->data);
- return 0;
- }
- static SQRESULT sq_fann_scale_input(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FLOAT(v, 2, value);
- fann_type fv = value;
- fann_scale_input(self->ann, &fv);
- sq_pushfloat(v, fv);
- return 1;
- }
- static SQRESULT sq_fann_descale_output(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_FLOAT(v, 2, value);
- fann_type fv = value;
- fann_descale_input(self->ann, &fv);
- sq_pushfloat(v, fv);
- return 1;
- }
- static SQRESULT sq_fann_save(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- SQ_GET_STRING(v, 2, net_fname);
- fann_save(self->ann, net_fname);
- return 0;
- }
- static SQRESULT sq_fann_print_connections(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- fann_print_connections(self->ann);
- return 0;
- }
- static SQRESULT sq_fann_print_parameters(HSQUIRRELVM v)
- {
- SQ_FUNC_VARS_NO_TOP(v);
- SQ_GET_FANN_INSTANCE(v, 1);
- fann_print_parameters(self->ann);
- return 0;
- }
- #define _DECL_FUNC(name,nparams,tycheck) {_SC(#name),sq_fann_##name,nparams,tycheck}
- static SQRegFunction sq_fann_methods[] =
- {
- _DECL_FUNC(constructor, -2,_SC("x s|f|a i")),
- _DECL_FUNC(copy, 1,_SC("x")),
- _DECL_FUNC(save, 2,_SC("xs")),
- _DECL_FUNC(learning_rate, -1,_SC("xf")),
- _DECL_FUNC(learning_momentum, -1,_SC("xf")),
- _DECL_FUNC(training_algorithm, -1,_SC("xi")),
- _DECL_FUNC(train_error_function, -1,_SC("xi")),
- _DECL_FUNC(train_stop_function, -1,_SC("xi")),
- _DECL_FUNC(set_activation_function_hidden, 2,_SC("xi")),
- _DECL_FUNC(set_activation_function_output, 2,_SC("xi")),
- _DECL_FUNC(activation_steepness, -3,_SC("xnnn")),
- _DECL_FUNC(set_activation_steepness_hidden, 2,_SC("xn")),
- _DECL_FUNC(set_activation_steepness_output, 2,_SC("xn")),
- _DECL_FUNC(set_quickprop_decay, 2,_SC("xn")),
- _DECL_FUNC(set_quickprop_mu, 2,_SC("xn")),
- _DECL_FUNC(set_rprop_increase_factor, 2,_SC("xn")),
- _DECL_FUNC(set_rprop_decrease_factor, 2,_SC("xn")),
- _DECL_FUNC(set_rprop_delta_min, 2,_SC("xn")),
- _DECL_FUNC(set_rprop_delta_max, 2,_SC("xn")),
- _DECL_FUNC(set_cascade_output_change_fraction, 2,_SC("xn")),
- _DECL_FUNC(set_cascade_output_stagnation_epochs, 2,_SC("xi")),
- _DECL_FUNC(set_cascade_candidate_change_fraction, 2,_SC("xn")),
- _DECL_FUNC(set_cascade_candidate_stagnation_epochs, 2,_SC("xi")),
- _DECL_FUNC(set_cascade_weight_multiplier, 2,_SC("xn")),
- _DECL_FUNC(set_cascade_candidate_limit, 2,_SC("xn")),
- _DECL_FUNC(set_cascade_max_out_epochs, 2,_SC("xi")),
- _DECL_FUNC(set_cascade_max_cand_epochs, 2,_SC("xi")),
- _DECL_FUNC(set_cascade_num_candidate_groups, 2,_SC("xi")),
- _DECL_FUNC(randomize_weights, 3,_SC("xnn")),
- _DECL_FUNC(reset_MSE, 1,_SC("x")),
- _DECL_FUNC(get_MSE, 1,_SC("x")),
- _DECL_FUNC(get_errstr, 1,_SC("x")),
- _DECL_FUNC(get_errno, 1,_SC("x")),
- _DECL_FUNC(reset_errno, 1,_SC("x")),
- _DECL_FUNC(get_num_input, 1,_SC("x")),
- _DECL_FUNC(get_num_output, 1,_SC("x")),
- _DECL_FUNC(get_bit_fail, 1,_SC("x")),
- _DECL_FUNC(print_connections, 1,_SC("x")),
- _DECL_FUNC(print_parameters, 1,_SC("x")),
- _DECL_FUNC(train_on_data, 5,_SC("xxiif")),
- _DECL_FUNC(cascadetrain_on_data, 5,_SC("xxiif")),
- _DECL_FUNC(train_on_file, 5,_SC("xsiif")),
- _DECL_FUNC(test, 3,_SC("xaa")),
- _DECL_FUNC(test_data, 2,_SC("xx")),
- _DECL_FUNC(run, 2,_SC("xa")),
- _DECL_FUNC(set_scaling_params, 6,_SC("xxnnnn")),
- _DECL_FUNC(set_input_scaling_params, 4,_SC("xxnn")),
- _DECL_FUNC(set_output_scaling_params, 4,_SC("xxnn")),
- _DECL_FUNC(init_weights, 2,_SC("xx")),
- _DECL_FUNC(scale_train, 2,_SC("xx")),
- _DECL_FUNC(scale_input, 2,_SC("xn")),
- _DECL_FUNC(descale_output, 2,_SC("xn")),
- _DECL_FUNC(set_callback, -2,_SC("xc.")),
- {0,0}
- };
- #undef _DECL_FUNC
- typedef struct {
- const SQChar *Str;
- SQInteger Val;
- } KeyIntType, * KeyIntPtrType;
- static KeyIntType module_constants[] = {
- #define MK_CONST(c) {_SC(#c), FANN_##c}
- MK_CONST(LINEAR),
- MK_CONST(THRESHOLD),
- MK_CONST(THRESHOLD_SYMMETRIC),
- MK_CONST(SIGMOID),
- MK_CONST(SIGMOID_STEPWISE),
- MK_CONST(SIGMOID_SYMMETRIC),
- MK_CONST(SIGMOID_SYMMETRIC_STEPWISE),
- MK_CONST(GAUSSIAN),
- MK_CONST(GAUSSIAN_STEPWISE),
- MK_CONST(ELLIOT),
- MK_CONST(ELLIOT_SYMMETRIC),
- MK_CONST(GAUSSIAN_SYMMETRIC),
- MK_CONST(LINEAR_PIECE),
- MK_CONST(LINEAR_PIECE_SYMMETRIC),
- MK_CONST(SIN_SYMMETRIC),
- MK_CONST(COS_SYMMETRIC),
- MK_CONST(SIN),
- MK_CONST(COS),
- MK_CONST(TRAIN_INCREMENTAL),
- MK_CONST(TRAIN_BATCH),
- MK_CONST(TRAIN_RPROP),
- MK_CONST(TRAIN_QUICKPROP),
- MK_CONST(TRAIN_SARPROP),
- MK_CONST(ERRORFUNC_LINEAR),
- MK_CONST(ERRORFUNC_TANH),
- MK_CONST(STOPFUNC_MSE),
- MK_CONST(STOPFUNC_BIT),
- MK_CONST(NETTYPE_LAYER),
- MK_CONST(NETTYPE_SHORTCUT),
- {0,0}
- };
- #ifdef __cplusplus
- extern "C" {
- #endif
- /* This defines a function that opens up your library. */
- SQRESULT sqext_register_fann (HSQUIRRELVM v) {
- sq_pushstring(v,sq_fann_training_data_TAG,-1);
- sq_newclass(v,SQFalse);
- sq_settypetag(v,-1,(void*)sq_fann_training_data_TAG);
- sq_insert_reg_funcs(v, sq_fann_training_data_methods);
- sq_newslot(v,-3,SQTrue);
- sq_pushstring(v,sq_fann_TAG,-1);
- sq_newclass(v,SQFalse);
- sq_settypetag(v,-1,(void*)sq_fann_TAG);
- sq_insert_reg_funcs(v, sq_fann_methods);
- //add constants
- KeyIntPtrType KeyIntPtr;
- for (KeyIntPtr = module_constants; KeyIntPtr->Str; KeyIntPtr++) {
- sq_pushstring(v, KeyIntPtr->Str, -1); //first the key
- sq_pushinteger(v, KeyIntPtr->Val); //then the value
- sq_newslot(v, -3, SQFalse); //store then
- }
- sq_newslot(v,-3,SQTrue);
- return SQ_OK;
- }
- #ifdef __cplusplus
- }
- #endif
|