#include <stdio.h>
#include <math.h>
#include "camera.h"

static void calc_sample_pos_rec(int sidx, float xsz, float ysz, float *pos);

/* Default camera: position left to Vector3's default constructor, a vertical
 * field of view of pi / 4 radians, no valid cached matrix and no primary ray
 * direction cache allocated yet.
 */
Camera::Camera()
	: vfov(M_PI / 4.0),
	  cached_matrix_valid(false),
	  rdir_cache(0),
	  rdir_cache_width(0),
	  rdir_cache_height(0)
{
}

/* Camera placed at p, with a vertical field of view of pi / 4 radians, no
 * valid cached matrix and no primary ray direction cache allocated yet.
 */
Camera::Camera(const Vector3 &p)
	: pos(p),
	  vfov(M_PI / 4.0),
	  cached_matrix_valid(false),
	  rdir_cache(0),
	  rdir_cache_width(0),
	  rdir_cache_height(0)
{
}

/* Releases the primary ray direction cache allocated by get_primary_ray.
 *
 * NOTE(review): rdir_cache is an owning raw pointer. Unless the header
 * suppresses copying, a copied Camera would share the same array and both
 * destructors would delete it -- confirm copy construction/assignment are
 * disabled, or replace the cache with a container.
 */
Camera::~Camera()
{
	delete [] rdir_cache;
}

void Camera::set_fov(float vfov)
{
	this->vfov = vfov;

	// invalidate the dir cache
	delete [] rdir_cache;
}

float Camera::get_fov() const
{
	return vfov;
}

/* Moves the camera to pos. The camera matrix is recomputed on the next
 * get_matrix call.
 */
void Camera::set_position(const Vector3 &pos)
{
	// the cached matrix may depend on the position: force a rebuild
	cached_matrix_valid = false;
	this->pos = pos;
}

/* Returns a reference to the camera position (valid while the camera lives). */
const Vector3 &Camera::get_position() const
{
	return this->pos;
}

/* Returns the camera matrix, recomputing it through the virtual calc_matrix
 * only when it has been invalidated since the last call.
 */
const Matrix4x4 &Camera::get_matrix() const
{
	if(cached_matrix_valid) {
		return cached_matrix;
	}

	calc_matrix(&cached_matrix);
	cached_matrix_valid = true;
	return cached_matrix;
}

/* Returns the image-plane position of antialiasing sample `sample` for pixel
 * (x, y) of an xsz x ysz image.
 *
 * The image plane spans [-aspect, aspect] horizontally and [-1, 1]
 * vertically, with +y up (pixel row 0 maps to the top). Sample 0 is the
 * pixel center; higher sample numbers are offset inside the pixel by
 * calc_sample_pos_rec.
 */
Vector2 Camera::calc_sample_pos(int x, int y, int xsz, int ysz, int sample) const
{
	float ppos[2];
	float aspect = (float)xsz / (float)ysz;

	// extent of one pixel on the image plane
	float pwidth = 2.0 * aspect / (float)xsz;
	float pheight = 2.0 / (float)ysz;

	// start from the pixel center: the sub-sample offsets are +/- a quarter
	// pixel (and finer), so they must be taken relative to the center to stay
	// inside this pixel instead of straddling its corner
	ppos[0] = ((float)x + 0.5f) * pwidth - aspect;
	ppos[1] = 1.0 - ((float)y + 0.5f) * pheight;

	calc_sample_pos_rec(sample, pwidth, pheight, ppos);
	return Vector2(ppos[0], ppos[1]);
}

/* Returns the primary ray through pixel (x, y) of an xsz x ysz image for
 * antialiasing sample number `sample` (0 is the pixel center).
 *
 * The ray starts at the camera position. Its direction is built in camera
 * space from the image-plane position given by calc_sample_pos, with the
 * image plane at z = 1 / tan(vfov / 2). It is normalized and then rotated
 * into world space by the camera matrix with its translation part removed.
 *
 * Sample-0 directions depend only on the pixel, the image size and the fov,
 * so they are cached for the whole image. The cache is rebuilt when the image
 * size changes or the cache has been dropped. Other (jittered) samples are
 * computed on demand.
 *
 * NOTE: x must be in [0, xsz) and y in [0, ysz); they are not checked.
 * The cache is mutated from this const method, so concurrent callers must
 * not race on the first call for a new image size.
 */
Ray Camera::get_primary_ray(int x, int y, int xsz, int ysz, int sample) const
{
	// distance from the eye to the image plane spanning [-1, 1] vertically
	const double zdist = 1.0 / tan(vfov / 2.0);

	Vector3 dir;
	if(sample == 0) {
		if(!rdir_cache || rdir_cache_width != xsz || rdir_cache_height != ysz) {
			printf("calculating primary ray direction cache\n");

			delete [] rdir_cache;
			rdir_cache = 0;	// don't keep a dangling pointer if new throws
			rdir_cache = new Vector3[xsz * ysz];

			Vector3 *rdir = rdir_cache;
			for(int i=0; i<ysz; i++) {
				for(int j=0; j<xsz; j++) {
					Vector2 ppos = calc_sample_pos(j, i, xsz, ysz, 0);

					rdir->x = ppos.x;
					rdir->y = ppos.y;
					rdir->z = zdist;
					rdir->normalize();

					rdir++;
				}
			}
			rdir_cache_width = xsz;
			rdir_cache_height = ysz;
		}
		dir = rdir_cache[y * xsz + x];
	} else {
		// jittered sub-samples differ per sample number and are not cached
		Vector2 ppos = calc_sample_pos(x, y, xsz, ysz, sample);

		dir.x = ppos.x;
		dir.y = ppos.y;
		dir.z = zdist;
		dir.normalize();
	}

	// transform the ray direction with the camera matrix, with translation
	// stripped so that only the orientation is applied to the direction
	Matrix4x4 mat = get_matrix();
	mat.m[0][3] = mat.m[1][3] = mat.m[2][3] = mat.m[3][0] = mat.m[3][1] = mat.m[3][2] = 0.0;
	mat.m[3][3] = 1.0;

	Ray ray;
	ray.origin = pos;
	ray.dir = dir.transformed(mat);
	return ray;
}

/* Default target camera: the Camera base is default-constructed and the
 * target keeps Vector3's default value.
 */
TargetCamera::TargetCamera()
{
}

/* Target camera positioned at pos and looking towards targ. */
TargetCamera::TargetCamera(const Vector3 &pos, const Vector3 &targ)
	: Camera(pos),
	  target(targ)
{
	// nothing else to set up: the base class provides the remaining defaults
}

/* Points the camera at targ. The camera matrix depends on the target, so it
 * is recomputed on the next get_matrix call.
 */
void TargetCamera::set_target(const Vector3 &targ)
{
	cached_matrix_valid = false;
	target = targ;
}

/* Returns a reference to the point the camera looks at. */
const Vector3 &TargetCamera::get_target() const
{
	return this->target;
}

/* Builds the camera matrix for a camera at pos looking at target.
 *
 * Columns 0-2 hold the orthonormal camera basis (right, up, forward) and
 * column 3 the camera position. World +Y is used as the reference up
 * direction; when the view direction is (nearly) parallel to it, +Z is used
 * instead so the basis does not collapse.
 *
 * NOTE(review): if target == pos the view direction is undefined; the
 * result then depends on what Vector3::normalized() does with a zero vector.
 */
void TargetCamera::calc_matrix(Matrix4x4 *mat) const
{
	Vector3 up(0, 1, 0);
	Vector3 dir = (target - pos).normalized();
	Vector3 right = cross_product(up, dir);

	float rlen = right.length();
	if(rlen < 1e-5) {
		// looking straight along the world up axis: up x dir degenerates
		right = cross_product(Vector3(0, 0, 1), dir);
		rlen = right.length();
	}
	// |up x dir| is sin(angle between them), so right must be normalized or
	// the basis (and every transformed ray direction) gets scaled down
	right = right / rlen;
	up = cross_product(dir, right);

	*mat = Matrix4x4(
			right.x, up.x, dir.x, pos.x,
			right.y, up.y, dir.y, pos.y,
			right.z, up.z, dir.z, pos.z,
			0.0, 0.0, 0.0, 1.0);
}

void FlyCamera::input_move(float x, float y, float z)
{
	static const Vector3 vfwd(0, 0, 1), vright(1, 0, 0);

	Vector3 k = vfwd.transformed(rot);
	Vector3	i = vright.transformed(rot);
	Vector3 j = cross_product(k, i);

	pos += i * x + j * y + k * z;
	cached_matrix_valid = false;
}

void FlyCamera::input_rotate(float x, float y, float z)
{
	Vector3 axis(x, y, z);
	float axis_len = axis.length();
	if(fabs(axis_len) < 1e-5) {
		return;
	}
	rot.rotate(axis / axis_len, -axis_len);
	rot.normalize();

	cached_matrix_valid = false;
}

/* Builds the camera matrix from the orientation quaternion and the camera
 * position.
 *
 * The rotation fills the upper 3x3 part. Like TargetCamera::calc_matrix,
 * the position goes in column 3. That is why input_move and set_position
 * invalidate the cached matrix, and why get_primary_ray strips that column
 * before transforming directions.
 */
void FlyCamera::calc_matrix(Matrix4x4 *mat) const
{
	Matrix3x3 rmat = rot.get_rotation_matrix();
	*mat = rmat;

	// translation column, matching TargetCamera's layout
	mat->m[0][3] = pos.x;
	mat->m[1][3] = pos.y;
	mat->m[2][3] = pos.z;
}

/* Offsets pos (an x, y pair, modified in place) to the position of sample
 * number sidx inside an xsz x ysz square, by recursive subdivision.
 *
 * The offsets are relative to pos, which should be the square's center.
 * Sample 0 is the center itself. Otherwise each level picks a quadrant with
 * (sidx - 1) % 4, moves to that quadrant's center (+/- a quarter of the
 * current size), then continues with (sidx - 1) / 4 in a square of half the
 * size. Samples 1-4 are thus the quadrant centers, 5-20 the sub-quadrant
 * centers, and so on. Negative sample numbers are treated like 0.
 */
static void calc_sample_pos_rec(int sidx, float xsz, float ysz, float *pos)
{
	static const float subpt[4][2] = {
		{-0.25, -0.25}, {0.25, -0.25}, {-0.25, 0.25}, {0.25, 0.25}
	};

	// the tail recursion, unrolled. Stopping at sidx <= 0 (not just == 0)
	// prevents a negative remainder from indexing subpt out of bounds.
	while(sidx > 0) {
		int quadrant = (sidx - 1) % 4;
		pos[0] += subpt[quadrant][0] * xsz;
		pos[1] += subpt[quadrant][1] * ysz;

		sidx = (sidx - 1) / 4;
		xsz /= 2;
		ysz /= 2;
	}
}
