remove Vectorize, simplify to just numpy.squeeze
parent
89bc31d703
commit
0e278f81c6
|
@ -420,12 +420,6 @@ void Module::emit_cython_pyx(FileWriter& pyxFile) const {
|
||||||
"from libcpp.pair cimport pair\n"
|
"from libcpp.pair cimport pair\n"
|
||||||
"from libcpp.string cimport string\n"
|
"from libcpp.string cimport string\n"
|
||||||
"from cython.operator cimport dereference as deref\n\n\n";
|
"from cython.operator cimport dereference as deref\n\n\n";
|
||||||
pyxFile.oss <<
|
|
||||||
R"rawstr(def Vectorize(*args):
|
|
||||||
ret = npp.squeeze(npp.asarray(args, dtype='float'))
|
|
||||||
if ret.ndim == 0: ret = npp.expand_dims(ret, axis=0)
|
|
||||||
return ret
|
|
||||||
)rawstr";
|
|
||||||
|
|
||||||
// all classes include all forward declarations
|
// all classes include all forward declarations
|
||||||
std::vector<Class> allClasses = expandedClasses;
|
std::vector<Class> allClasses = expandedClasses;
|
||||||
|
|
|
@ -94,7 +94,7 @@ std::string ReturnType::pyx_casting(const std::string& var,
|
||||||
if (isEigen()) {
|
if (isEigen()) {
|
||||||
string s = "ndarray_copy(" + var + ")";
|
string s = "ndarray_copy(" + var + ")";
|
||||||
if (pyxClassName() == "Vector")
|
if (pyxClassName() == "Vector")
|
||||||
return "Vectorize(" + s + ")";
|
return s + ".squeeze()";
|
||||||
else return s;
|
else return s;
|
||||||
}
|
}
|
||||||
else if (isNonBasicType()) {
|
else if (isNonBasicType()) {
|
||||||
|
|
Loading…
Reference in New Issue