#include "scene.h"
#include "image.h"
#include "rend.h"

/* maximum recursion depth for the secondary (reflection) rays that shade()
 * spawns through trace_ray(); rays at this depth are not reflected further */
#define MAX_RAY_DEPTH	6

/* writes a pseudo-random unit direction to *x, *y, *z, advancing the rand_r
 * state stored in *seedp */
static void rand_dir(float *x, float *y, float *z, unsigned int *seedp);

/* Fresnel reflectance for unpolarized light, computed either from the incident
 * and transmitted directions plus the surface normal, or directly from the
 * cosines of the incidence and transmission angles */
static float fresnel(const Vector3 &inc, const Vector3 &trans, const Vector3 &norm, float ior_inc, float ior_trans);
static float fresnel(float cos_inc, float cos_trans, float ior_inc, float ior_trans);


/* framebuffer: allocated by init_renderer(), resized by resize_renderer(),
 * filled by render_frame() and deleted by destroy_renderer() */
static Image *frame;
/* scene being rendered, set by init_renderer(); this file never deletes it */
static Scene *scn;

/* samples per pixel, set through set_render_samples().
 * NOTE(review): nothing in this file reads it; render_frame() traces exactly
 * one primary ray per pixel. */
static int num_samples = 1;

/* (Re)initialize the renderer for scene `s` with an xsz x ysz framebuffer.
 *
 * Any previous framebuffer is released first. Returns false if the new
 * framebuffer could not be created; in that case `frame` is left null (never
 * dangling), so later calls to destroy_renderer()/init_renderer() stay safe.
 */
bool init_renderer(Scene *s, int xsz, int ysz)
{
	delete frame;
	frame = nullptr;	// don't keep a dangling pointer if allocation below fails

	scn = s;

	try {
		frame = new Image(xsz, ysz);
	}
	catch(...) {
		return false;
	}
	return true;
}

/* Release the framebuffer created by init_renderer().
 * The pointer is reset so that a subsequent init_renderer() or
 * destroy_renderer() call does not delete the same object twice.
 */
void destroy_renderer()
{
	delete frame;
	frame = nullptr;
}

void set_render_samples(int samples)
{
	num_samples = samples;
}

/* Resize the framebuffer to xsz x ysz pixels (passing 0 as the pixel data
 * argument of Image::set_pixels).
 * Does nothing if no framebuffer exists yet (init_renderer() not called, or
 * destroy_renderer() already ran) instead of dereferencing a null pointer.
 */
void resize_renderer(int xsz, int ysz)
{
	if(!frame) {
		return;
	}
	frame->set_pixels(xsz, ysz, 0);
}

float *render_frame(long msec)
{
	scn->prepare_xform(msec);

	Camera *cam = scn->get_camera();

	unsigned int rseed = 1;

	for(int i=0; i<frame->ysz; i++) {
		Color *pixel = frame->pixels + i * frame->xsz;

		for(int j=0; j<frame->xsz; j++) {
			Ray ray = cam->get_primary_ray(j, i, frame->xsz, frame->ysz, 0);
			ray.time = msec;
			ray.user = &rseed;

			*pixel++ = trace_ray(scn, ray);
		}
	}

	return (float*)frame->pixels;
}

/* Trace `ray` through `scn` and return the color it sees.
 *
 * A ray that hits nothing gets the environment color. A hit is shaded with
 * shade(). When fog is enabled (fog_start >= 0), the shaded color is blended
 * toward the environment color by a factor that grows linearly from 0 at
 * fog_start to 1 at fog_end and is clamped at 1.
 */
Color trace_ray(const Scene *scn, const Ray &ray, int rdepth)
{
	HitPoint hit;
	if(!scn->intersect(ray, &hit)) {
		return scn->env_color(ray);
	}

	if(scn->fog_start >= 0.0) {
		float t = (hit.dist - scn->fog_start) / (scn->fog_end - scn->fog_start);
		if(t > 0.0) {
			return lerp(shade(scn, ray, hit, rdepth), scn->env_color(ray), t > 1.0 ? 1.0 : t);
		}
	}

	return shade(scn, ray, hit, rdepth);
}

/* Compute the color of the surface point `hit` as seen along `ray`.
 *
 * The result adds up:
 *  - the material's emission (modulated by its texture, if any),
 *  - one random image-based-lighting sample from scn->envmap_conv, counted only
 *    when that sample direction is not blocked by scene geometry,
 *  - diffuse + Phong specular contributions from every light that is not
 *    occluded,
 *  - a recursively traced mirror reflection, while rdepth < MAX_RAY_DEPTH.
 *
 * The rand_r seed used for random sampling is taken from ray.user. That is
 * presumably always set by render_frame() and propagated to secondary rays;
 * TODO confirm for any other callers.
 */
Color shade(const Scene *scn, const Ray &ray, const HitPoint &hit, int rdepth)
{
	const Material *mat = &hit.obj->material;

	// if we're leaving the object, we need to invert the normal (and ior)
	Vector3 normal;
	bool entering;	// only consumed by the (disabled) refraction code below
	if(dot_product(hit.normal, ray.dir) <= 0.0) {
		normal = hit.normal;
		entering = true;
	} else {
		normal = -hit.normal;
		entering = false;
	}

	Vector3 vdir = -ray.dir;

	Color diffuse_color = mat->diffuse;
	Color tex_color{1, 1, 1};
	if(mat->tex) {
		tex_color *= mat->tex->sample(hit);
		diffuse_color *= tex_color;
	}

	Color color = mat->emission * tex_color;

	// image-based lighting
	if(scn->envmap_conv) {
		/* start from a copy of the incoming ray, like the shadow and reflection
		 * rays below, so that per-ray state (time, user/seed pointer) carries
		 * over to the visibility test; then pick a random direction in the
		 * hemisphere around the normal */
		Ray envray = ray;
		envray.origin = hit.pos;
		rand_dir(&envray.dir.x, &envray.dir.y, &envray.dir.z, (unsigned int*)ray.user);
		if(dot_product(envray.dir, normal) < 0.0) {
			envray.dir = -envray.dir;
		}

		HitPoint env_hit;
		if(!scn->intersect(envray, &env_hit)) {
			Vector3 dir = envray.dir;
			color += scn->envmap_conv->sample(dir.x, dir.y, dir.z) * diffuse_color;
		}
	}

	for(Light *lt: scn->lights) {

		/* construct a shadow ray to determine if there is an uninterrupted
		 * path between the intersection point and the light source
		 */
		Ray shadow_ray = ray;
		shadow_ray.origin = hit.pos;
		shadow_ray.dir = lt->pos - hit.pos;

		/* the interval [0, 1] represents the part of the ray from the origin
		 * to the light. We don't care about intersections behind the origin
		 * of the shadow ray (behind the surface of the object), or after the
		 * light source. We only care if there's something in between hiding the
		 * light.
		 */
		HitPoint shadow_hit;
		if(scn->intersect(shadow_ray, &shadow_hit) && shadow_hit.dist < 1.0f) {
			continue;	// skip this light, it's hidden from view
		}

		// calculate the light direction
		Vector3 ldir = shadow_ray.dir.normalized();
		// calculate the reflected light direction
		Vector3 lref = ldir.reflection(normal);

		float diffuse = std::max(dot_product(ldir, normal), 0.0f);
		float specular = pow(std::max(dot_product(lref, vdir), 0.0f), mat->shininess);

		color += (diffuse_color * diffuse + mat->specular * specular) * lt->color;
	}

	Color spec_col;

	if(mat->reflectivity > 0.001f && rdepth < MAX_RAY_DEPTH) {
		Ray refl_ray{ray};
		refl_ray.origin = hit.pos;
		// negated because reflection() mirrors about the normal a vector that
		// points away from the surface, while ray.dir points toward it
		refl_ray.dir = -ray.dir.reflection(normal);

		spec_col += trace_ray(scn, refl_ray, rdepth + 1) * mat->reflectivity;
	}

	/*if(mat->transparency > 0.001f && rdepth < MAX_RAY_DEPTH) {
		float from_ior = entering ? 1.0 : mat->ior;
		float to_ior = entering ? mat->ior : 1.0;

		Ray refr_ray{ray};
		refr_ray.origin = hit.pos;
		refr_ray.dir = ray.dir.refraction(normal, from_ior / to_ior);

		Color tcol = trace_ray(scn, refr_ray, rdepth + 1) * mat->transparency;

		float fres = fresnel(ray.dir, refr_ray.dir, normal, from_ior, to_ior);
		spec_col = spec_col * fres + tcol * (1.0 - fres);
	}*/

	return color + spec_col;
}


/* Produce a pseudo-random unit vector in (*x, *y, *z).
 *
 * Two rand_r() draws from *seedp (advancing it) select an azimuth in
 * [0, 2*pi] and a polar angle acos(2v - 1) in [0, pi], which are then
 * converted from spherical to Cartesian coordinates.
 */
static void rand_dir(float *x, float *y, float *z, unsigned int *seedp)
{
	// two uniform numbers in [0, 1], drawn in this order
	const float u = static_cast<float>(rand_r(seedp)) / RAND_MAX;
	const float v = static_cast<float>(rand_r(seedp)) / RAND_MAX;

	const float azimuth = 2.0 * M_PI * u;
	const float polar = acos(2.0 * v - 1.0);

	const auto sin_polar = sin(polar);

	*x = cos(azimuth) * sin_polar;
	*y = sin(azimuth) * sin_polar;
	*z = cos(polar);
}

/* Fresnel reflectance given the incident direction `inc`, the transmitted
 * direction `trans` and the surface normal `norm`.
 * Both directions are negated before the dot product. This presumably expects
 * them to point along the direction of travel, with `norm` facing the incident
 * side, so that both cosines come out positive (TODO confirm at call sites).
 */
static float fresnel(const Vector3 &inc, const Vector3 &trans, const Vector3 &norm, float ior_inc, float ior_trans)
{
	return fresnel(dot_product(-inc, norm), dot_product(-trans, norm), ior_inc, ior_trans);
}

/* Fresnel equations for unpolarized light.
 *
 * cos_inc / cos_trans: cosines of the incidence and transmission angles.
 * ior_inc / ior_trans: refractive indices on the incident and transmitted
 * sides.
 * Returns the average of the squared amplitude reflection coefficients for
 * the two polarization components.
 */
static float fresnel(float cos_inc, float cos_trans, float ior_inc, float ior_trans)
{
	const float ti = ior_trans * cos_inc;
	const float it = ior_inc * cos_trans;
	const float r_par = (ti - it) / (ti + it);

	const float ii = ior_inc * cos_inc;
	const float tt = ior_trans * cos_trans;
	const float r_perp = (ii - tt) / (ii + tt);

	return 0.5f * (r_par * r_par + r_perp * r_perp);
}
