import React, { useEffect, useRef, useState } from 'react';
import * as THREE from 'three';
import { OrbitControls } from 'three/examples/jsm/controls/OrbitControls';
import { FontLoader } from 'three/examples/jsm/loaders/FontLoader';
import { TextGeometry } from 'three/examples/jsm/geometries/TextGeometry';

// Gradient-descent tuning knobs.
// Hard cap on how many update steps one optimisation run may perform.
const maxIterations = 1e3;
// Default step size α applied to each gradient update.
const learningRate = 1e-2;

/**
 * Mean squared error of the line y = w·x + b over a set of samples:
 * J(w, b) = (1/m) · Σ (yᵢ − (w·xᵢ + b))².
 *
 * @param {number} w - slope of the line.
 * @param {number} b - intercept of the line.
 * @param {Array<{x: number, y: number}>} [points=dataPoints] - samples to evaluate;
 *   defaults to the module's sample data so existing two-argument calls are unchanged.
 * @returns {number} the MSE, or 0 when `points` is empty (instead of 0/0 = NaN).
 */
function calculateMSE(w, b, points = dataPoints) {
  if (points.length === 0) {
    return 0;
  }
  const total = points.reduce((sum, point) => sum + (point.y - (w * point.x + b)) ** 2, 0);
  return total / points.length;
}

/**
 * Gradient of the mean squared error J(w, b) with respect to the line parameters:
 *   ∂J/∂w = (2/m) · Σ (w·xᵢ + b − yᵢ) · xᵢ
 *   ∂J/∂b = (2/m) · Σ (w·xᵢ + b − yᵢ)
 *
 * @param {number} w - slope of the line.
 * @param {number} b - intercept of the line.
 * @param {Array<{x: number, y: number}>} [points=dataPoints] - samples to differentiate over;
 *   defaults to the module's sample data so existing two-argument calls are unchanged.
 * @returns {{dw: number, db: number}} the partial derivatives; both 0 for an empty set
 *   (instead of NaN), which lets a descent loop stop immediately.
 */
function calculateGradients(w, b, points = dataPoints) {
  const m = points.length;
  if (m === 0) {
    return { dw: 0, db: 0 };
  }
  let dw = 0;
  let db = 0;
  for (const { x, y } of points) {
    const error = w * x + b - y; // prediction minus observation
    dw += error * x;
    db += error;
  }
  return { dw: (2 / m) * dw, db: (2 / m) * db };
}

// Default sample (x, y) observations for the module-level MSE / gradient helpers,
// written as [x, y] pairs and expanded into { x, y } records.
const dataPoints = [
  [-5, -3], [-4, -1], [-3, 0], [-2, 1], [-1, 1.5],
  [0, 2], [1, 3], [2, 4], [3, 5], [4, 6],
].map(([x, y]) => ({ x, y }));

// Squared residual of one sample (x, y) against the line with slope m and intercept b.
// Averaging this over every sample yields the MSE used for the cost surface.
function costFunction(x, y, m, b) {
  const predicted = m * x + b;
  const residual = predicted - y;
  return residual ** 2;
}

/**
 * Builds a three.js scene inside the element `#containerId` that plots the mean-squared-error
 * surface J(w, b) of a line fit and marks the current (slope, intercept) on it.
 *
 * World axes: x = slope w, z = intercept b, y = J(w, b) * costHeightScale. The surface, the red
 * marker, the blue guide lines and the J tick labels all use this one mapping, so the marker sits
 * on the surface.
 *
 * Any previous content of the container is removed first. The function starts its own
 * requestAnimationFrame loop and a window "resize" listener; call the returned function to stop
 * them and release the renderer's GPU resources (each call otherwise holds a WebGL context).
 *
 * @param {string} containerId - id of the element that receives the WebGL canvas.
 * @param {number} slope - current w.
 * @param {number} intercept - current b.
 * @param {number} mse - J(w, b) at the current parameters; used for the label and marker height.
 * @param {function(object): void} setCostTextMesh - receives the "J(w, b) = …" text mesh once the
 *   font has loaded (not called if the scene was disposed before that).
 * @param {Array<{x: number, y: number}>} [points=dataPoints] - samples the surface is computed from.
 * @returns {(function(): void)|undefined} dispose function, or undefined if the container is missing.
 */
function initCostFunction3D(containerId, slope, intercept, mse, setCostTextMesh, points = dataPoints) {
  // Grids and surface share one 20 × 20 footprint centred on the origin.
  const gridSize = 20;
  const gridDivisions = 30;
  const half = gridSize / 2;
  // Vertical exaggeration for J so the surface fits inside the grid box.
  const costHeightScale = 0.1;
  const fontUrl = 'https://threejs.org/examples/fonts/helvetiker_regular.typeface.json';

  const container = document.getElementById(containerId);
  if (!container) {
    console.error(`Container with ID ${containerId} not found.`);
    return undefined;
  }

  // Clear previous render if reinitializing.
  container.innerHTML = '';

  const scene = new THREE.Scene();
  scene.background = new THREE.Color(0xffffff);

  // Guard against a zero-height container producing an infinite aspect ratio.
  const aspectOf = () => container.clientWidth / Math.max(container.clientHeight, 1);
  const camera = new THREE.PerspectiveCamera(75, aspectOf(), 0.1, 1000);
  camera.position.set(0, 15, 25);

  const renderer = new THREE.WebGLRenderer({ antialias: true });
  renderer.setSize(container.clientWidth, container.clientHeight);
  container.appendChild(renderer.domElement);

  // Blue guide lines on the floor (y = 0): one along z through the current w, one along x through the current b.
  const lineMaterial = new THREE.LineBasicMaterial({ color: 0x0000ff });
  const slopeLine = new THREE.Line(
    new THREE.BufferGeometry().setFromPoints([
      new THREE.Vector3(slope, 0, -half),
      new THREE.Vector3(slope, 0, half),
    ]),
    lineMaterial
  );
  scene.add(slopeLine);
  const interceptLine = new THREE.Line(
    new THREE.BufferGeometry().setFromPoints([
      new THREE.Vector3(-half, 0, intercept),
      new THREE.Vector3(half, 0, intercept),
    ]),
    lineMaterial
  );
  scene.add(interceptLine);

  // Red marker at the current (w, J, b), scaled like the surface so it lies on it.
  const marker = new THREE.Mesh(
    new THREE.SphereGeometry(0.2, 32, 32),
    new THREE.MeshStandardMaterial({ color: 0xff0000 })
  );
  marker.position.set(slope, mse * costHeightScale, intercept);
  scene.add(marker);

  scene.add(new THREE.AxesHelper(10));
  scene.add(new THREE.AmbientLight(0xffffff, 0.8));

  // Two grey back walls and an orange floor grid.
  const sideWall = new THREE.GridHelper(gridSize, gridDivisions, 0x808080, 0x808080);
  sideWall.rotation.x = Math.PI / 2;
  sideWall.rotation.z = Math.PI / 2;
  sideWall.position.set(-half, half, 0);
  scene.add(sideWall);

  const backWall = new THREE.GridHelper(gridSize, gridDivisions, 0x808080, 0x808080);
  backWall.rotation.x = Math.PI / 2;
  backWall.position.set(0, half, -half);
  scene.add(backWall);

  const floorGrid = new THREE.GridHelper(gridSize, gridDivisions, 0xff7700, 0xff7700);
  floorGrid.position.set(0, 0, 0);
  scene.add(floorGrid);

  // Wireframe cost surface J(w, b) over w, b ∈ [-half, half].
  const surfaceGeometry = new THREE.PlaneGeometry(gridSize, gridSize, gridDivisions, gridDivisions);
  const costSurface = new THREE.Mesh(
    surfaceGeometry,
    new THREE.MeshBasicMaterial({ color: 0x0077ff, wireframe: true })
  );
  costSurface.rotation.x = -Math.PI / 2;
  scene.add(costSurface);

  const meanCost = (w, b) => {
    if (points.length === 0) {
      return 0;
    }
    return points.reduce((sum, point) => sum + costFunction(point.x, point.y, w, b), 0) / points.length;
  };

  const positions = surfaceGeometry.attributes.position;
  for (let i = 0; i < positions.count; i += 1) {
    const w = positions.getX(i);
    // Rotating the plane by -90° about X sends local +Y to world -Z, so negate to make world z equal b
    // (the same b axis the marker and guide lines use). Local Z becomes world height.
    const b = -positions.getY(i);
    positions.setZ(i, meanCost(w, b) * costHeightScale);
  }
  positions.needsUpdate = true;

  // Set once the scene is torn down so a late font load does not touch it.
  let disposed = false;

  const addLabel = (font, material, text, size, x, y, z) => {
    const mesh = new THREE.Mesh(new TextGeometry(text, { font, size, height: 0.05 }), material);
    mesh.position.set(x, y, z);
    scene.add(mesh);
    return mesh;
  };

  // Load the font once for both the axis labels and the live J(w, b) readout.
  new FontLoader().load(
    fontUrl,
    (font) => {
      if (disposed) {
        return;
      }
      const labelMaterial = new THREE.MeshBasicMaterial({ color: 0x000000 });
      // w ticks along the front edge of the floor.
      for (let i = -half; i <= half; i += 2) {
        addLabel(font, labelMaterial, String(i), 0.4, i, 0, half);
      }
      // J ticks up the front-left corner; a tick at height h corresponds to J = h / costHeightScale.
      for (let h = 0; h <= gridSize; h += 2) {
        addLabel(font, labelMaterial, String(Math.round(h / costHeightScale)), 0.4, -half, h, half);
      }
      // b ticks along the right edge of the floor.
      for (let i = -half; i <= half; i += 2) {
        addLabel(font, labelMaterial, String(i), 0.4, half, 0, i);
      }
      addLabel(font, labelMaterial, 'w', 0.4, 0, 0, 13);
      addLabel(font, labelMaterial, 'b', 0.4, 13, 0, 0);
      addLabel(font, labelMaterial, 'J(w,b)', 0.4, -half, half, 13);

      const costTextMesh = addLabel(
        font,
        new THREE.MeshBasicMaterial({ color: 0x000000 }),
        `J(w, b) = ${mse.toFixed(2)}`,
        0.5,
        -8,
        10,
        10
      );
      setCostTextMesh(costTextMesh);
    },
    undefined,
    (error) => console.error('Failed to load font for the cost function labels.', error)
  );

  const controls = new OrbitControls(camera, renderer.domElement);
  controls.enableDamping = true;

  let frameId = 0;
  const animate = () => {
    frameId = requestAnimationFrame(animate);
    controls.update();
    renderer.render(scene, camera);
  };
  animate();

  const handleResize = () => {
    camera.aspect = aspectOf();
    camera.updateProjectionMatrix();
    renderer.setSize(container.clientWidth, container.clientHeight);
  };
  window.addEventListener('resize', handleResize);

  return () => {
    disposed = true;
    cancelAnimationFrame(frameId);
    window.removeEventListener('resize', handleResize);
    controls.dispose();
    // Materials are shared between some objects (guide lines, labels); dispose() tolerates repeat calls.
    scene.traverse((object) => {
      if (object.geometry) {
        object.geometry.dispose();
      }
      const materials = Array.isArray(object.material) ? object.material : [object.material];
      materials.forEach((material) => {
        if (material) {
          material.dispose();
        }
      });
    });
    renderer.dispose();
    // dispose() alone does not free the WebGL context; browsers cap how many may be alive at once.
    renderer.forceContextLoss();
    renderer.domElement.remove();
  };
}

// Main React component
const MachineLearningVisualization = () => {
  const [slope, setSlope] = useState(1);
  const [intercept, setIntercept] = useState(0);
  const [mse, setMSE] = useState(0);
  const [costTextMesh, setCostTextMesh] = useState(null);
  const [learningRate, setLearningRate] = useState(0.01);
  const [isRunning, setIsRunning] = useState(false);
  const [intervalId, setIntervalId] = useState(null);
  const [xAxisLabel, setXAxisLabel] = useState("Dose");
const [yAxisLabel, setYAxisLabel] = useState("Effect");
const [dataPoints, setDataPoints] = useState([
  { x: -5, y: -3 }, { x: -4, y: -1 }, { x: -3, y: 0 }, { x: -2, y: 1 },
  { x: -1, y: 1.5 }, { x: 0, y: 2 }, { x: 1, y: 3 }, { x: 2, y: 4 },
  { x: 3, y: 5 }, { x: 4, y: 6 }
]);


  const canvasRef = useRef(null);
  const performGradientDescent = () => {
    let w = slope;
    let b = intercept;
    let currentIteration = 0;
  
    // Set up an interval for stepwise gradient descent
    const intervalId = setInterval(() => {
      if (currentIteration >= maxIterations) {
        clearInterval(intervalId); // Stop the interval after reaching max iterations
        return;
      }
  
      // Calculate gradients for the current values of w and b
      const { dw, db } = calculateGradients(w, b);
      w -= learningRate * dw; // Update slope (w) with gradient descent step
      b -= learningRate * db; // Update intercept (b) with gradient descent step
  
      // Optional: Stop early if gradients are very small
      if (Math.abs(dw) < 1e-6 && Math.abs(db) < 1e-6) {
        clearInterval(intervalId);
        return;
      }
// Update state variables to trigger re-renders
setSlope(w);
setIntercept(b);
setMSE(calculateMSE(w, b)); currentIteration++; // Increment iteration count
}, 500); // Run every 100 ms, adjust for slower or faster visualization
};  
  useEffect(() => {
    const calculateMSE = () => {
      return (
        dataPoints.reduce((sum, point) => sum + (point.y - (slope * point.x + intercept)) ** 2, 0) /
        dataPoints.length
      );
    };

    const mseValue = calculateMSE();
    setMSE(mseValue);

    initCostFunction3D('ml-visualization-container', slope, intercept, mseValue, setCostTextMesh);
    drawLineAndResidualsOnCanvas();
  }, [slope, intercept]);

  useEffect(() => {
    if (costTextMesh) {
      const newTextGeometry = new TextGeometry(`J(w, b) = ${mse.toFixed(2)}`, {
        font: costTextMesh.geometry.parameters.font,
        size: 0.5,
        height: 0.05,
      });
      costTextMesh.geometry.dispose();
      costTextMesh.geometry = newTextGeometry;
    }
  }, [mse]);

  // Add new data point
  const addDataPoint = (x, y) => {
    setDataPoints([...dataPoints, { x: parseFloat(x), y: parseFloat(y) }]);
  };

  const drawLineAndResidualsOnCanvas = () => {
    const canvas = canvasRef.current;
    const ctx = canvas.getContext('2d');
    
    // Clear canvas
    ctx.clearRect(0, 0, canvas.width, canvas.height);
  
    // Find data bounds to scale accordingly
    const xValues = dataPoints.map(point => point.x);
    const yValues = dataPoints.map(point => point.y);
    const minX = Math.min(...xValues) - 1; // Extra padding
    const maxX = Math.max(...xValues) + 1;
    const minY = Math.min(...yValues) - 1;
    const maxY = Math.max(...yValues) + 1;
  
    // Scale factor for fitting points
    const scaleX = canvas.width / (maxX - minX);
    const scaleY = canvas.height / (maxY - minY);
  
    // Function to scale coordinates to canvas
    const scalePoint = (x, y) => ({
      x: (x - minX) * scaleX,
      y: canvas.height - (y - minY) * scaleY,
    });
  
    // Draw each data point and residual line
    ctx.fillStyle = 'red';
    dataPoints.forEach(point => {
      const scaledPoint = scalePoint(point.x, point.y);
      ctx.beginPath();
      ctx.arc(scaledPoint.x, scaledPoint.y, 5, 0, 2 * Math.PI);
      ctx.fill();
  
      // Calculate point on trend line for residual line
      const trendLineY = slope * point.x + intercept;
      const trendLinePoint = scalePoint(point.x, trendLineY);
  
      // Draw residual line
      ctx.strokeStyle = 'blue';
      ctx.lineWidth = 1;
      ctx.beginPath();
      ctx.moveTo(scaledPoint.x, scaledPoint.y);
      ctx.lineTo(trendLinePoint.x, trendLinePoint.y);
      ctx.stroke();
    });
  
    // Draw regression line
    ctx.strokeStyle = 'green';
    ctx.lineWidth = 2;
    ctx.beginPath();
    const startPoint = scalePoint(minX, slope * minX + intercept);
    const endPoint = scalePoint(maxX, slope * maxX + intercept);
    ctx.moveTo(startPoint.x, startPoint.y);
    ctx.lineTo(endPoint.x, endPoint.y);
    ctx.stroke();
  
    // Draw axis labels with padding
    ctx.fillStyle = 'black';
    ctx.font = '16px Arial';
    ctx.textAlign = 'center';
    ctx.fillText(xAxisLabel, canvas.width / 2, canvas.height - 20);
  
    ctx.save();
    ctx.translate(20, canvas.height / 2);
    ctx.rotate(-Math.PI / 2);
    ctx.fillText(yAxisLabel, 0, 0);
    ctx.restore();
  };

  useEffect(() => {
    drawLineAndResidualsOnCanvas();
  }, [dataPoints, slope, intercept]);

  useEffect(() => {
    drawLineAndResidualsOnCanvas();
  }, [dataPoints, slope, intercept]);
  const handleSlopeChange = (e) => setSlope(parseFloat(e.target.value));
  const handleInterceptChange = (e) => setIntercept(parseFloat(e.target.value));


  useEffect(() => {
    drawLineAndResidualsOnCanvas();
  }, [xAxisLabel, yAxisLabel]);

  

  return (
    <div>
      <center>
      <div className="text-container">
  <h2>Machine Learning Visualization</h2>
  <p>This tool visualizes the effect of changing the slope and intercept on the Mean Squared Error (MSE) cost function surface for a linear model.</p>
  
  <p><strong>Model:</strong> The linear model used here is <code>f<sub>w,b</sub>(x) = wx + b</code>, where <code>w</code> represents the slope, and <code>b</code> represents the intercept. By adjusting these parameters, we aim to find the line that best fits the given data points.</p>
  
  <p><strong>Goal:</strong> The goal is to minimize the cost function <code>J(w, b)</code> by adjusting <code>w</code> and <code>b</code>. The cost function, defined as the Mean Squared Error (MSE), quantifies the difference between the model's predictions and the actual data points. Lower values of <code>J(w, b)</code> indicate a better fit of the model to the data.</p>
  
  <p><strong>Gradient Descent:</strong> Gradient descent is an optimization algorithm used to minimize the cost function <code>J(w, b)</code>. It involves iteratively adjusting the parameters <code>w</code> and <code>b</code> by moving in the direction of the negative gradient (the steepest descent) of <code>J(w, b)</code>. The update equations are:</p>
  <ul>
    <li><code>w = w - α * (∂J(w, b) / ∂w)</code></li>
    <li><code>b = b - α * (∂J(w, b) / ∂b)</code></li>
  </ul>
  
  <p>In these equations, <code>α</code> (alpha) is the learning rate, which controls the size of each step. The partial derivatives, <code>∂J(w, b) / ∂w</code> and <code>∂J(w, b) / ∂b</code>, represent the gradients of the cost function with respect to <code>w</code> and <code>b</code>. By repeatedly applying these updates, the algorithm gradually converges to the values of <code>w</code> and <code>b</code> that minimize the cost function.</p>
  
  <p>Click the <strong>"Optimize with Gradient Descent"</strong> button to automatically find the optimal values of <code>w</code> and <code>b</code> for this dataset.</p>
</div>
      </center>
      <div style={{ display: 'flex', height: '100vh', backgroundColor: '#f0f0f5' }}>
        <div style={{ width: '60%', position: 'relative', margin: '20px', border: '2px solid #ccc', borderRadius: '10px', overflow: 'hidden', boxShadow: '0 4px 8px rgba(0, 0, 0, 0.1)' }}>
          <div id="ml-visualization-container" style={{ width: '100%', height: '100%', backgroundColor: 'white' }}></div>
        </div>
        <div style={{ width: '40%', padding: '20px', backgroundColor: '#fff', boxShadow: '0 4px 8px rgba(0, 0, 0, 0.1)' }}>
          <h3>2D Line Plot</h3>
          <canvas ref={canvasRef} width={400} height={400} style={{ border: '1px solid #ccc', marginBottom: '20px' }}></canvas>
          <div>
            <label>Slope (w): {slope}</label>
            <input
              type="range"
              min="-2"
              max="2"
              step="0.1"
              value={slope}
              onChange={handleSlopeChange}
              style={{ width: '100%' }}
            />
          </div>
          <div style={{ marginTop: '10px' }}>
            <label>Intercept (b): {intercept}</label>
            <input
              type="range"
              min="-5"
              max="5"
              step="0.1"
              value={intercept}
              onChange={handleInterceptChange}
              style={{ width: '100%' }}
            />
          </div>
          <button onClick={performGradientDescent}>Optimize with Gradient Descent</button>

          <h3>Add Data Point</h3>
          <label>
            X-coordinate:
            <input type="number" id="x-coord" step="0.1" />
          </label>
          <br />
          <label>
            Y-coordinate:
            <input type="number" id="y-coord" step="0.1" />
          </label>
          <br />
          
          <button onClick={() => {
            const x = document.getElementById('x-coord').value;
            const y = document.getElementById('y-coord').value;
            addDataPoint(x, y);
          }}>Add Point</button>
          <br />
          <label></label>
        
          <label>
  X-axis Label:
  <input
    type="text"
    value={xAxisLabel}
    onChange={(e) => setXAxisLabel(e.target.value)}
  />
</label>
<label>
  Y-axis Label:
  <input
    type="text"
    value={yAxisLabel}
    onChange={(e) => setYAxisLabel(e.target.value)}
  />
</label>
          <div style={{ marginTop: '20px', fontWeight: 'bold' }}>
            <p>Equation: f(x) = {slope.toFixed(2)}x + {intercept.toFixed(2)}</p>
        

          </div>
          <br/><br/>
        </div>
      </div>
    </div>
  );
};

export default MachineLearningVisualization;
