///////////////////////////////////////////////////////////////////////////////
//
//  Copyright (2008) Alexander Stukowski
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2 of the License, or
//  (at your option) any later version.
//
//  OVITO is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
///////////////////////////////////////////////////////////////////////////////

#include <scripting/Scripting.h>

#include <stdexcept>

namespace Scripting {

using namespace boost::python;

/// Element accessors that let Python scripts read and write individual matrix
/// entries through the bound "Get"/"Set" methods.
///
/// The first parameter plays the role of Python's `self`. Both indices are
/// forwarded unchanged to T::operator()(i, j); no range checking happens here.
template<class T>
struct matrix_wrapper
{
	/// Returns the entry m(i, j).
	static FloatType get(const T& m, size_t i, size_t j)
	{
		const FloatType entry = m(i, j);
		return entry;
	}

	/// Stores `value` into the entry m(i, j).
	static void set(T& m, size_t i, size_t j, FloatType value)
	{
		m(i, j) = value;
	}
};

/// Static helpers used to expose vector-like value types (Vector3, Point3,
/// Vector4, Color, ColorA) to Python.
///
/// Boost.Python binds these functions as methods/properties of the wrapped
/// class; the first parameter plays the role of Python's `self`.
template<class T>
struct vector_wrapper
{
	/// Returns Base::Length(v).
	static FloatType Length(const T& v) { return Base::Length(v); }
	/// Returns Base::LengthSquared(v).
	static FloatType LengthSquared(const T& v) { return Base::LengthSquared(v); }
	/// Returns Base::Normalize(v) as a new Vector3.
	static Vector3 Normalize(const T& v) { return Base::Normalize(v); }
	/// Returns Base::CrossProduct(v, v2).
	static Vector3 CrossProduct(const T& v, const T& v2) { return Base::CrossProduct(v, v2); }
	/// Returns Base::DotProduct(v, v2).
	static FloatType DotProduct(const T& v, const T& v2) { return Base::DotProduct(v, v2); }

	/// Implements Python's __getitem__: returns component i of v.
	///
	/// Throws std::out_of_range if i >= v.size(). Boost.Python translates that
	/// exception into IndexError. That translation is required: these classes
	/// define no __iter__, so `for x in v` / `list(v)` call __getitem__(0, 1, ...)
	/// until IndexError is raised. Without the check they would read past the
	/// end of the vector.
	static FloatType Get(const T& v, size_t i)
	{
		CheckIndex(v, i);
		return v[i];
	}

	/// Implements Python's __setitem__: assigns `value` to component i of v.
	///
	/// Throws std::out_of_range (-> IndexError) if i >= v.size(), instead of
	/// writing outside the object.
	static void Set(T& v, size_t i, FloatType value)
	{
		CheckIndex(v, i);
		v[i] = value;
	}

private:
	/// Rejects component indices that are not smaller than v.size().
	static void CheckIndex(const T& v, size_t i)
	{
		if(i >= static_cast<size_t>(v.size()))
			throw std::out_of_range("vector index out of range");
	}
};

/// Registers the Python wrappers for the linear-algebra and color value types:
/// Vector3, Point3, Quaternion, Scaling, Plane3, AffineTransformation, Matrix3,
/// Vector4, Box3, Rotation, Color and ColorA.
///
/// Each class_<> is added to the currently active boost::python::scope, so this
/// is expected to be called from a Boost.Python module's initialization code.
/// Factory functions (Identity, RotationX, ...) are exposed as static methods.
void ExportLinAlg()
{
	// Vector3: components, arithmetic operators, vector products and the Python
	// sequence protocol (__len__/__getitem__/__setitem__).
	class_<Vector3>("Vector3", init< optional<NullVector> >())
		.def(init<FloatType, FloatType, FloatType>())
		.def_readwrite("X", &Vector3::X)
		.def_readwrite("Y", &Vector3::Y)
		.def_readwrite("Z", &Vector3::Z)
		.def(self + other<Vector3>())
		.def(self += other<Vector3>())
		.def(self - other<Vector3>())
		.def(self -= other<Vector3>())
		.def(self * FloatType())
		.def(FloatType() * self)
		.def(self *= FloatType())
		.def(self / FloatType())
		.def(self /= FloatType())
		.def(-self)
		.def(Origin() + self)
		.def(self == other<Vector3>())
		.def(self != other<Vector3>())
		.add_property("Length", &vector_wrapper<Vector3>::Length)
		.add_property("LengthSquared", &vector_wrapper<Vector3>::LengthSquared)
		.def("Normalize", &vector_wrapper<Vector3>::Normalize, return_value_policy<return_by_value>())
		.def("CrossProduct", &vector_wrapper<Vector3>::CrossProduct, return_value_policy<return_by_value>())
		.def("DotProduct", &vector_wrapper<Vector3>::DotProduct)
		.def("__len__", &Vector3::size)
		.def("__getitem__", &vector_wrapper<Vector3>::Get)
		.def("__setitem__", &vector_wrapper<Vector3>::Set)
		.def("__str__", &Vector3::toString)
	;

	// Point3: components, point/vector arithmetic and the sequence protocol.
	class_<Point3>("Point3", init< optional<Origin> >())
		.def(init<FloatType, FloatType, FloatType>())
		.def_readwrite("X", &Point3::X)
		.def_readwrite("Y", &Point3::Y)
		.def_readwrite("Z", &Point3::Z)
		.def(self + other<Point3>())
		.def(self + other<Vector3>())
		.def(self - other<Vector3>())
		.def(self - other<Point3>())
		.def(self += other<Vector3>())
		.def(self -= other<Vector3>())
		.def(self * FloatType())
		.def(self / FloatType())
		.def(self - Origin())
		.def(self == other<Point3>())
		.def(self != other<Point3>())
		.def("__len__", &Point3::size)
		.def("__getitem__", &vector_wrapper<Point3>::Get)
		.def("__setitem__", &vector_wrapper<Point3>::Set)
		.def("__str__", &Point3::toString)
	;

	class_<Quaternion>("Quaternion", init< optional<IdentityQuaternion> >())
		.def(init<const AffineTransformation&>())
		.def_readwrite("X", &Quaternion::X)
		.def_readwrite("Y", &Quaternion::Y)
		.def_readwrite("Z", &Quaternion::Z)
		.def_readwrite("W", &Quaternion::W)
		.def("Inverse", &Quaternion::inverse, return_value_policy<return_by_value>())
		.def(self == other<Quaternion>())
		.def(self != other<Quaternion>())
		.def("__str__", &Quaternion::toString)
	;

	class_<Scaling>("Scaling", init< optional<IdentityScaling> >())
		.def("Inverse", &Scaling::inverse, return_value_policy<return_by_value>())
		.def(self == other<Scaling>())
		.def(self != other<Scaling>())
		.def("__str__", &Scaling::toString)
	;

	class_<Plane3>("Plane3", init<>())
		.def(init<const Vector3&, const FloatType&>())
		.def(init<const Point3&, const Vector3&>())
		.def(init<const Point3&, const Point3&, const Point3&>())
		.def(init<const Point3&, const Point3&, const Point3&, bool>())
		.def_readwrite("Normal", &Plane3::normal)
		.def_readwrite("Dist", &Plane3::dist)
		.def(-self)
		.def(self == other<Plane3>())
		.def("NormalizePlane", &Plane3::normalizePlane)
		.def("ClassifyPoint", &Plane3::classifyPoint)
		.def("PointDistance", &Plane3::pointDistance)
		.def("__str__", &Plane3::toString)
	;

	class_<AffineTransformation>("AffineTransformation", init<>())
		.def(init<FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType>())
		.def(init<FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType>())
		.def(init<const Matrix3&>())
		.add_property("Determinant", &AffineTransformation::determinant)
		// This property must NOT be named "Translation". The static factory method
		// "Translation" registered further below would replace that class attribute
		// (Boost.Python's def() overwrites an existing non-function attribute of the
		// same name), and the property would become unreachable from Python. The
		// static method keeps its name for backward compatibility.
		.add_property("TranslationVector",
			make_function(&AffineTransformation::getTranslation, return_internal_reference<>()),
			&AffineTransformation::setTranslation)
		.def("Inverse", &AffineTransformation::inverse, return_value_policy<return_by_value>())
		.def("GetColumn", &AffineTransformation::getColumn, return_internal_reference<>())
		.def("SetColumn", &AffineTransformation::setColumn)
		.def("Get", &matrix_wrapper<AffineTransformation>::get)
		.def("Set", &matrix_wrapper<AffineTransformation>::set)
		.def(self * other<AffineTransformation>())
		.def(self * other<Point3>())
		.def(self * other<Vector3>())
		.def(self * other<Plane3>())
		.def(self * other<FloatType>())
		.def(self * other<Matrix3>())
		.def("Identity", &AffineTransformation::identity, return_value_policy<return_by_value>())
		.staticmethod("Identity")
		.def("RotationX", &AffineTransformation::rotationX, return_value_policy<return_by_value>())
		.staticmethod("RotationX")
		.def("RotationY", &AffineTransformation::rotationY, return_value_policy<return_by_value>())
		.staticmethod("RotationY")
		.def("RotationZ", &AffineTransformation::rotationZ, return_value_policy<return_by_value>())
		.staticmethod("RotationZ")
		// The casts select a specific overload of the overloaded C++ factory functions.
		.def("FromQuaternion", (AffineTransformation (*)(const Quaternion&))&AffineTransformation::rotation, return_value_policy<return_by_value>())
		.staticmethod("FromQuaternion")
		.def("FromRotation", (AffineTransformation (*)(const Rotation&))&AffineTransformation::rotation, return_value_policy<return_by_value>())
		.staticmethod("FromRotation")
		.def("IsoScaling", (AffineTransformation (*)(FloatType))&AffineTransformation::scaling, return_value_policy<return_by_value>())
		.staticmethod("IsoScaling")
		.def("Scaling", (AffineTransformation (*)(const Base::Scaling& scaling))&AffineTransformation::scaling, return_value_policy<return_by_value>())
		.staticmethod("Scaling")
		.def("Shear", &AffineTransformation::shear, return_value_policy<return_by_value>())
		.staticmethod("Shear")
		.def<AffineTransformation (*)(const Vector3&)>("Translation", &AffineTransformation::translation, return_value_policy<return_by_value>())
		.staticmethod("Translation")
		.def("__str__", &AffineTransformation::toString)
	;

	class_<Matrix3>("Matrix3", init<>())
		.def(init<FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType,FloatType>())
		.add_property("Determinant", &Matrix3::determinant)
		.def("Inverse", &Matrix3::inverse, return_value_policy<return_by_value>())
		.def("GetColumn", &Matrix3::getColumn, return_internal_reference<>())
		.def("SetColumn", &Matrix3::setColumn)
		.def("Get", &matrix_wrapper<Matrix3>::get)
		.def("Set", &matrix_wrapper<Matrix3>::set)
		.def(self * other<AffineTransformation>())
		.def(self * other<Matrix3>())
		.def(self * other<Point3>())
		.def(self * other<Vector3>())
		.def(self * other<FloatType>())
		.def("__str__", &Matrix3::toString)
		.def("Identity", &Matrix3::identity, return_value_policy<return_by_value>())
		.staticmethod("Identity")
		.def("RotationX", &Matrix3::rotationX, return_value_policy<return_by_value>())
		.staticmethod("RotationX")
		.def("RotationY", &Matrix3::rotationY, return_value_policy<return_by_value>())
		.staticmethod("RotationY")
		.def("RotationZ", &Matrix3::rotationZ, return_value_policy<return_by_value>())
		.staticmethod("RotationZ")
		.def("Scaling", &Matrix3::scaling, return_value_policy<return_by_value>())
		.staticmethod("Scaling")
	;

	// Vector4 exposes only read access to its components (no __setitem__).
	class_<Vector4>("Vector4", init< optional<NullVector> >())
		.def(init<FloatType, FloatType, FloatType, FloatType>())
		.def("__len__", &Vector4::size)
		.def("__getitem__", &vector_wrapper<Vector4>::Get)
		.def("__str__", &Vector4::toString)
	;

	class_<Box3>("Box3", init<>())
		.def(init<const Point3&, const Point3&>())
		.def(init<const Point3&, FloatType>())
		.add_property("IsEmpty", &Box3::isEmpty)
		.add_property("Center", make_function(&Box3::center, return_value_policy<return_by_value>()))
		// Cast selects the zero-argument const overload of Box3::size.
		.add_property("Size", make_function((Vector3 (Box3::*)() const)&Box3::size, return_value_policy<return_by_value>()))
		.def("Contains", &Box3::contains)
		.def("ClassifyPoint", &Box3::classifyPoint)
		.def("__str__", &Box3::toString)
	;

	class_<Rotation>("Rotation", init< optional<NullRotation> >())
		.def(init<const Vector3&, FloatType>())
		.def(init<const AffineTransformation&>())
		.def(init<const Quaternion&>())
		.def_readwrite("Axis", &Rotation::axis)
		.def_readwrite("Angle", &Rotation::angle)
		.def("Inverse", &Rotation::inverse, return_value_policy<return_by_value>())
		.def("__str__", &Rotation::toString)
		// Rotation * Rotation is currently not exported; the disabled binding is kept for reference.
		//.def(self * other<const Rotation&>)
		.def(self += other<Rotation>())
		.def(self -= other<Rotation>())
		.def(self == other<Rotation>())
		.def(self != other<Rotation>())
	;

	// Color: RGB components, arithmetic, clamping helpers and the sequence protocol.
	class_<Color>("Color")
		.def(init<FloatType, FloatType, FloatType>())
		.def_readwrite("R", &Color::r)
		.def_readwrite("G", &Color::g)
		.def_readwrite("B", &Color::b)
		.def(self + other<Color>())
		.def(self += other<Color>())
		.def(self * other<Color>())
		.def(self * FloatType())
		.def(FloatType() * self)
		.def(self == other<Color>())
		.def(self != other<Color>())
		.def("ClampMin", &Color::clampMin)
		.def("ClampMax", &Color::clampMax)
		.def("ClampMinMax", &Color::clampMinMax)
		.def("SetWhite", &Color::setWhite)
		.def("SetBlack", &Color::setBlack)
		.def("__len__", &Color::size)
		.def("__getitem__", &vector_wrapper<Color>::Get)
		.def("__setitem__", &vector_wrapper<Color>::Set)
		.def("__str__", &Color::toString)
	;

	// ColorA: like Color, plus the A component.
	class_<ColorA>("ColorA")
		.def(init<FloatType, FloatType, FloatType, FloatType>())
		.def_readwrite("R", &ColorA::r)
		.def_readwrite("G", &ColorA::g)
		.def_readwrite("B", &ColorA::b)
		.def_readwrite("A", &ColorA::a)
		.def(self + other<ColorA>())
		.def(self += other<ColorA>())
		.def(self * other<ColorA>())
		.def(self * FloatType())
		.def(FloatType() * self)
		.def(self == other<ColorA>())
		.def(self != other<ColorA>())
		.def("ClampMin", &ColorA::clampMin)
		.def("ClampMax", &ColorA::clampMax)
		.def("ClampMinMax", &ColorA::clampMinMax)
		.def("SetWhite", &ColorA::setWhite)
		.def("SetBlack", &ColorA::setBlack)
		.def("__len__", &ColorA::size)
		.def("__getitem__", &vector_wrapper<ColorA>::Get)
		.def("__setitem__", &vector_wrapper<ColorA>::Set)
		.def("__str__", &ColorA::toString)
	;
}

};
